mirror of
https://github.com/Stability-AI/stablediffusion.git
synced 2025-12-05 14:30:01 +01:00
* Force cast to fp32 to avoid atten layer overflow
This commit is contained in:
@@ -167,9 +167,13 @@ class CrossAttention(nn.Module):
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
# force cast to fp32 to avoid overflowing
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
q, k = q.float(), k.float()
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
del q, k
|
||||
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
|
||||
Reference in New Issue
Block a user