no-ema in config, adapt noise augmentor

Robin Rombach
2023-01-29 23:38:31 +01:00
parent d7980a2ae6
commit c81b231008
3 changed files with 28 additions and 19 deletions


@@ -16,6 +16,7 @@ model:
conditioning_key: crossattn-adm
scale_factor: 0.18215
monitor: val/loss_simple_ema
+ use_ema: False
embedder_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
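The new `use_ema: False` key switches off EMA weight tracking for this model. A minimal sketch of the pattern such a flag typically gates (illustrative names, not the repo's exact wiring):

```python
import copy
import torch
import torch.nn as nn

def build_model(use_ema: bool):
    model = nn.Linear(4, 4)
    # with use_ema: False there is no shadow copy to create, update, or load
    ema_model = copy.deepcopy(model) if use_ema else None
    return model, ema_model

@torch.no_grad()
def ema_update(model, ema_model, decay: float = 0.9999):
    # shadow <- decay * shadow + (1 - decay) * online weights
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1 - decay)

model, ema_model = build_model(use_ema=False)
assert ema_model is None
```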


@@ -38,7 +38,7 @@ class ClassEmbedder(nn.Module):
c = batch[key][:, None]
if self.ucg_rate > 0. and not disable_dropout:
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
- c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+ c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
c = c.long()
c = self.embedding(c)
return c
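The reformatted line is the unconditional-guidance (ucg) label dropout: with probability `ucg_rate`, a class id is swapped for `n_classes - 1`, which acts as the "unconditional" class. A standalone sketch of the same masking with toy values:

```python
import torch

def drop_labels(c: torch.Tensor, ucg_rate: float, n_classes: int) -> torch.Tensor:
    # mask is 1 where the original label is kept, 0 where it is dropped
    mask = 1. - torch.bernoulli(torch.ones_like(c, dtype=torch.float) * ucg_rate)
    c = mask * c + (1 - mask) * torch.ones_like(c, dtype=torch.float) * (n_classes - 1)
    return c.long()

torch.manual_seed(0)
labels = torch.randint(0, 10, (8, 1))   # batch of class ids
print(drop_labels(labels, ucg_rate=0.5, n_classes=1001))
```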
@@ -58,18 +58,20 @@ def disabled_train(self, mode=True):
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
- self.max_length = max_length # TODO: typical value?
+ self.max_length = max_length  # TODO: typical value?
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
- #self.train = disabled_train
+ # self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
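`freeze()` puts the encoder in eval mode and disables gradients; the commented-out `self.train = disabled_train` line points at the companion trick of overriding `.train()` with a no-op so a training loop cannot flip the module back. A generic sketch of that pattern, not the repo's exact code:

```python
import torch.nn as nn

def disabled_train(self, mode: bool = True):
    """Overrides nn.Module.train so a frozen module stays in eval mode."""
    return self

encoder = nn.TransformerEncoderLayer(d_model=16, nhead=4)
encoder = encoder.eval()
encoder.train = disabled_train.__get__(encoder)   # bind the function as a method
for p in encoder.parameters():
    p.requires_grad = False

encoder.train()                # no-op now
assert not encoder.training    # still in eval mode
```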
@@ -93,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
"pooled",
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
super().__init__()
@@ -111,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def freeze(self):
self.transformer = self.transformer.eval()
- #self.train = disabled_train
+ # self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
@@ -119,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
- outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
@@ -166,7 +169,7 @@ class ClipImageEmbedder(nn.Module):
out = self.model.encode_image(self.preprocess(x))
out = out.to(x.dtype)
if self.ucg_rate > 0. and not no_dropout:
- out = torch.bernoulli((1.-self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
+ out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
return out
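The reformatted line implements per-sample embedding dropout for classifier-free guidance: each image embedding is kept with probability `1 - ucg_rate` and zeroed otherwise. A toy sketch:

```python
import torch

def drop_embeddings(out: torch.Tensor, ucg_rate: float) -> torch.Tensor:
    # one keep/zero decision per sample, broadcast over the feature dimension
    keep = torch.bernoulli((1. - ucg_rate) * torch.ones(out.shape[0], device=out.device))
    return keep[:, None] * out

torch.manual_seed(0)
emb = torch.randn(4, 8)   # (batch, embedding_dim)
print(drop_embeddings(emb, ucg_rate=0.5))
```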
@@ -175,10 +178,11 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = [
#"pooled",
# "pooled",
"last",
"penultimate"
]
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
freeze=True, layer="last"):
super().__init__()
@@ -218,7 +222,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
x = self.model.ln_final(x)
return x
- def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
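`text_transformer_forward` stops `layer_idx` blocks before the end of the transformer, which is how the `"penultimate"` layer option is realized. A toy version of the same early-break pattern (plain linear blocks standing in for OpenCLIP's resblocks):

```python
import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])

def forward_until(x: torch.Tensor, layer_idx: int) -> torch.Tensor:
    # layer_idx=1 -> run all but the last block, i.e. return the penultimate output
    for i, block in enumerate(blocks):
        if i == len(blocks) - layer_idx:
            break
        x = block(x)
    return x

penultimate = forward_until(torch.randn(2, 8), layer_idx=1)
```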
@@ -236,6 +240,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
super().__init__()
@@ -278,7 +283,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
def forward(self, image, no_dropout=False):
z = self.encode_with_vision_transformer(image)
if self.ucg_rate > 0. and not no_dropout:
- z = torch.bernoulli((1.-self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
+ z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
return z
def encode_with_vision_transformer(self, img):
@@ -289,14 +294,15 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
def encode(self, text):
return self(text)
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
clip_max_length=77, t5_max_length=77):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
def encode(self, text):
return self(text)
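The reformatted print reports parameter counts in millions via `count_params`. A sketch of how such a helper is commonly written (an assumption here; the repo imports its own version):

```python
import torch.nn as nn

def count_params(model: nn.Module) -> int:
    # total number of parameters, trainable or not
    return sum(p.numel() for p in model.parameters())

clip_like = nn.Linear(768, 768)
print(f"{clip_like.__class__.__name__} has {count_params(clip_like) * 1.e-6:.2f} M parameters")
```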
@@ -309,17 +315,22 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from ldm.modules.diffusionmodules.openaimodel import Timestep
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
- def __init__(self, *args, clip_stats_path, timestep_dim=256, **kwargs):
+ def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
super().__init__(*args, **kwargs)
- clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
+ if clip_stats_path is None:
+     clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
+ else:
+     clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
self.register_buffer("data_std", clip_std[None, :], persistent=False)
self.time_embed = Timestep(timestep_dim)
def scale(self, x):
# re-normalize to centered mean and unit variance
- x = (x - self.data_mean) * 1./self.data_std
+ x = (x - self.data_mean) * 1. / self.data_std
return x
def unscale(self, x):
@@ -337,4 +348,3 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
z = self.unscale(z)
noise_level = self.time_embed(noise_level)
return z, noise_level
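This is the "adapt noise augmentor" half of the commit: `clip_stats_path` becomes optional, and without a stats file the CLIP-embedding normalization defaults to zero mean and unit std, so `scale`/`unscale` reduce to identities. A self-contained sketch of that fallback (hypothetical class name, same arithmetic):

```python
import torch

class EmbeddingNormalizer:
    def __init__(self, clip_stats_path=None, dim=256):
        if clip_stats_path is None:
            # no stats file: identity normalization
            mean, std = torch.zeros(dim), torch.ones(dim)
        else:
            mean, std = torch.load(clip_stats_path, map_location="cpu")
        self.mean, self.std = mean[None, :], std[None, :]

    def scale(self, x):      # re-normalize to zero mean / unit variance
        return (x - self.mean) / self.std

    def unscale(self, x):    # back to the original statistics
        return x * self.std + self.mean

norm = EmbeddingNormalizer()              # no stats available
z = torch.randn(2, 256)
assert torch.allclose(norm.unscale(norm.scale(z)), z)
assert torch.equal(norm.scale(z), z)      # identity when mean=0, std=1
```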


@@ -274,7 +274,7 @@ if __name__ == "__main__":
st.title("Stable unCLIP")
mode = "txt2img"
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
- use_karlo = st.checkbox("Use KARLO prior", False) and version in ["Stable unCLIP-L"]
+ use_karlo = version in ["Stable unCLIP-L"] and st.checkbox("Use KARLO prior", False)
state = init(version=version, load_karlo_prior=use_karlo)
st.info(state["msg"])
prompt = st.text_input("Prompt", "a professional photograph of an astronaut riding a horse")
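The reorder leans on `and` short-circuiting: previously `st.checkbox` was evaluated (and the widget rendered) even for versions without a Karlo prior, with its result then discarded; checking the version first skips the widget call entirely. A minimal sketch with a stand-in for `st.checkbox` and a hypothetical non-Karlo version name:

```python
calls = []

def checkbox(label, default=False):
    # stand-in for st.checkbox: records that the widget would be rendered
    calls.append(label)
    return default

version = "Stable unCLIP-H"          # hypothetical version without a Karlo prior

# old order: widget call always happens, result then discarded
use_karlo_old = checkbox("Use KARLO prior") and version in ["Stable unCLIP-L"]

# new order: `and` short-circuits, so the widget call is skipped here
use_karlo_new = version in ["Stable unCLIP-L"] and checkbox("Use KARLO prior")

assert calls == ["Use KARLO prior"]  # only the old order rendered the checkbox
```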
@@ -306,8 +306,6 @@ if __name__ == "__main__":
sampler = DPMSolverSampler(state["model"])
elif sampler == "DDIM":
sampler = DDIMSampler(state["model"])
if st.checkbox("Try oscillating guidance?", False):
ucg_schedule = make_oscillating_guidance_schedule(num_steps=steps, max_weight=scale, min_weight=1.)
else:
raise ValueError(f"unknown sampler {sampler}!")
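The removed lines drop the "oscillating guidance" toggle from the sampler-selection block. `make_oscillating_guidance_schedule` itself is not shown in this diff; purely as an assumption, such a schedule could alternate between the full guidance weight and a minimum weight across steps:

```python
def make_oscillating_guidance_schedule(num_steps: int, max_weight: float, min_weight: float = 1.):
    # assumed behavior: alternate between strong and weak guidance per sampling step
    return [max_weight if i % 2 == 0 else min_weight for i in range(num_steps)]

print(make_oscillating_guidance_schedule(num_steps=6, max_weight=10., min_weight=1.))
# [10.0, 1.0, 10.0, 1.0, 10.0, 1.0]
```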