mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2025-12-05 14:30:01 +01:00
no-ema in config, adapt noiseaugmtor
This commit is contained in:
@@ -16,6 +16,7 @@ model:
|
||||
conditioning_key: crossattn-adm
|
||||
scale_factor: 0.18215
|
||||
monitor: val/loss_simple_ema
|
||||
use_ema: False
|
||||
|
||||
embedder_config:
|
||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
|
||||
|
||||
@@ -38,7 +38,7 @@ class ClassEmbedder(nn.Module):
|
||||
c = batch[key][:, None]
|
||||
if self.ucg_rate > 0. and not disable_dropout:
|
||||
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||
c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
|
||||
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
|
||||
c = c.long()
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
@@ -58,18 +58,20 @@ def disabled_train(self, mode=True):
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
|
||||
freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
if freeze:
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
@@ -93,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"pooled",
|
||||
"hidden"
|
||||
]
|
||||
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
@@ -111,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
@@ -119,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
|
||||
if self.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
elif self.layer == "pooled":
|
||||
@@ -166,7 +169,7 @@ class ClipImageEmbedder(nn.Module):
|
||||
out = self.model.encode_image(self.preprocess(x))
|
||||
out = out.to(x.dtype)
|
||||
if self.ucg_rate > 0. and not no_dropout:
|
||||
out = torch.bernoulli((1.-self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
|
||||
out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
|
||||
return out
|
||||
|
||||
|
||||
@@ -175,10 +178,11 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = [
|
||||
#"pooled",
|
||||
# "pooled",
|
||||
"last",
|
||||
"penultimate"
|
||||
]
|
||||
|
||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
||||
freeze=True, layer="last"):
|
||||
super().__init__()
|
||||
@@ -218,7 +222,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
@@ -236,6 +240,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP vision transformer encoder for images
|
||||
"""
|
||||
|
||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
||||
freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
|
||||
super().__init__()
|
||||
@@ -278,7 +283,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
|
||||
def forward(self, image, no_dropout=False):
|
||||
z = self.encode_with_vision_transformer(image)
|
||||
if self.ucg_rate > 0. and not no_dropout:
|
||||
z = torch.bernoulli((1.-self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
|
||||
z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
|
||||
return z
|
||||
|
||||
def encode_with_vision_transformer(self, img):
|
||||
@@ -289,14 +294,15 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
|
||||
clip_max_length=77, t5_max_length=77):
|
||||
super().__init__()
|
||||
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|
||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
||||
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
|
||||
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
@@ -309,17 +315,22 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
|
||||
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
|
||||
|
||||
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
|
||||
def __init__(self, *args, clip_stats_path, timestep_dim=256, **kwargs):
|
||||
def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
|
||||
if clip_stats_path is None:
|
||||
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
|
||||
else:
|
||||
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
|
||||
self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
|
||||
self.register_buffer("data_std", clip_std[None, :], persistent=False)
|
||||
self.time_embed = Timestep(timestep_dim)
|
||||
|
||||
def scale(self, x):
|
||||
# re-normalize to centered mean and unit variance
|
||||
x = (x - self.data_mean) * 1./self.data_std
|
||||
x = (x - self.data_mean) * 1. / self.data_std
|
||||
return x
|
||||
|
||||
def unscale(self, x):
|
||||
@@ -337,4 +348,3 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
|
||||
z = self.unscale(z)
|
||||
noise_level = self.time_embed(noise_level)
|
||||
return z, noise_level
|
||||
|
||||
|
||||
@@ -274,7 +274,7 @@ if __name__ == "__main__":
|
||||
st.title("Stable unCLIP")
|
||||
mode = "txt2img"
|
||||
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
||||
use_karlo = st.checkbox("Use KARLO prior", False) and version in ["Stable unCLIP-L"]
|
||||
use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
|
||||
state = init(version=version, load_karlo_prior=use_karlo)
|
||||
st.info(state["msg"])
|
||||
prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
|
||||
@@ -306,8 +306,6 @@ if __name__ == "__main__":
|
||||
sampler = DPMSolverSampler(state["model"])
|
||||
elif sampler == "DDIM":
|
||||
sampler = DDIMSampler(state["model"])
|
||||
if st.checkbox("Try oscillating guidance?", False):
|
||||
ucg_schedule = make_oscillating_guidance_schedule(num_steps=steps, max_weight=scale, min_weight=1.)
|
||||
else:
|
||||
raise ValueError(f"unknown sampler {sampler}!")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user