From 5999fc170363c5cdf60c3bfebb38db60dfcce217 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 3 May 2024 09:34:08 -0700 Subject: [PATCH] trust Tero Karras as well as https://github.com/lucidrains/meshgpt-pytorch/issues/64 and start removing groupnorms from all repos --- README.md | 3 +- .../classifier_free_guidance.py | 27 ++++---- .../denoising_diffusion_pytorch.py | 61 +++++++++---------- .../denoising_diffusion_pytorch_1d.py | 27 ++++---- .../guided_diffusion.py | 27 ++++---- .../simple_diffusion.py | 23 +++---- denoising_diffusion_pytorch/version.py | 2 +- 7 files changed, 79 insertions(+), 91 deletions(-) diff --git a/README.md b/README.md index 100ee6ea3..e6074d041 100644 --- a/README.md +++ b/README.md @@ -127,13 +127,14 @@ diffusion = GaussianDiffusion1D( ) training_seq = torch.rand(64, 32, 128) # features are normalized from 0 to 1 -dataset = Dataset1D(training_seq) # this is just an example, but you can formulate your own Dataset and pass it into the `Trainer1D` below loss = diffusion(training_seq) loss.backward() # Or using trainer +dataset = Dataset1D(training_seq) # this is just an example, but you can formulate your own Dataset and pass it into the `Trainer1D` below + trainer = Trainer1D( diffusion, dataset = dataset, diff --git a/denoising_diffusion_pytorch/classifier_free_guidance.py b/denoising_diffusion_pytorch/classifier_free_guidance.py index 59b5e2f91..19bbbd3d4 100644 --- a/denoising_diffusion_pytorch/classifier_free_guidance.py +++ b/denoising_diffusion_pytorch/classifier_free_guidance.py @@ -148,10 +148,10 @@ def forward(self, x): # building block modules class Block(nn.Module): - def __init__(self, dim, dim_out, groups = 8): + def __init__(self, dim, dim_out): super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) - self.norm = nn.GroupNorm(groups, dim_out) + self.norm = RMSNorm(dim_out) self.act = nn.SiLU() def forward(self, x, scale_shift = None): @@ -166,15 +166,15 @@ def forward(self, x, scale_shift = None): return x class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, classes_emb_dim = None, groups = 8): + def __init__(self, dim, dim_out, *, time_emb_dim = None, classes_emb_dim = None): super().__init__() self.mlp = nn.Sequential( nn.SiLU(), nn.Linear(int(time_emb_dim) + int(classes_emb_dim), dim_out * 2) ) if exists(time_emb_dim) or exists(classes_emb_dim) else None - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) + self.block1 = Block(dim, dim_out) + self.block2 = Block(dim_out, dim_out) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb = None, class_emb = None): @@ -258,7 +258,6 @@ def __init__( out_dim = None, dim_mults=(1, 2, 4, 8), channels = 3, - resnet_block_groups = 8, learned_variance = False, learned_sinusoidal_cond = False, random_fourier_features = False, @@ -283,8 +282,6 @@ def __init__( dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - block_klass = partial(ResnetBlock, groups = resnet_block_groups) - # time embeddings time_dim = dim * 4 @@ -328,23 +325,23 @@ def __init__( is_last = ind >= (num_resolutions - 1) self.downs.append(nn.ModuleList([ - block_klass(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim), - block_klass(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim), + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim), + ResnetBlock(dim_in, 
dim_in, time_emb_dim = time_dim, classes_emb_dim = classes_dim), Residual(PreNorm(dim_in, LinearAttention(dim_in))), Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) ])) mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim) + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim) self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, dim_head = attn_dim_head, heads = attn_heads))) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim) + self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = ind == (len(in_out) - 1) self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim), + ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim), + ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim, classes_emb_dim = classes_dim), Residual(PreNorm(dim_out, LinearAttention(dim_out))), Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) ])) @@ -352,7 +349,7 @@ def __init__( default_out_dim = channels * (1 if not learned_variance else 2) self.out_dim = default(out_dim, default_out_dim) - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim) + self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim) self.final_conv = nn.Conv2d(dim, self.out_dim, 1) def forward_with_cond_scale( diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py index 1ce813cbe..41e2cebb9 100644 --- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py @@ -8,8 +8,9 @@ import torch from torch import nn, einsum -from torch.cuda.amp import autocast import torch.nn.functional as F +from torch.nn import Module, ModuleList +from torch.cuda.amp import autocast from torch.utils.data import Dataset, DataLoader from torch.optim import Adam @@ -98,17 +99,18 @@ def Downsample(dim, dim_out = None): nn.Conv2d(dim * 4, default(dim_out, dim), 1) ) -class RMSNorm(nn.Module): +class RMSNorm(Module): def __init__(self, dim): super().__init__() + self.scale = dim ** 0.5 self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) def forward(self, x): - return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5) + return F.normalize(x, dim = 1) * self.g * self.scale # sinusoidal positional embeds -class SinusoidalPosEmb(nn.Module): +class SinusoidalPosEmb(Module): def __init__(self, dim, theta = 10000): super().__init__() self.dim = dim @@ -123,7 +125,7 @@ def forward(self, x): emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb -class RandomOrLearnedSinusoidalPosEmb(nn.Module): +class RandomOrLearnedSinusoidalPosEmb(Module): """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ @@ -142,11 +144,11 @@ def forward(self, x): # building block modules -class Block(nn.Module): - def __init__(self, dim, dim_out, 
groups = 8): +class Block(Module): + def __init__(self, dim, dim_out): super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) - self.norm = nn.GroupNorm(groups, dim_out) + self.norm = RMSNorm(dim_out) self.act = nn.SiLU() def forward(self, x, scale_shift = None): @@ -160,16 +162,16 @@ def forward(self, x, scale_shift = None): x = self.act(x) return x -class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): +class ResnetBlock(Module): + def __init__(self, dim, dim_out, *, time_emb_dim = None): super().__init__() self.mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) + self.block1 = Block(dim, dim_out) + self.block2 = Block(dim_out, dim_out) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb = None): @@ -186,7 +188,7 @@ def forward(self, x, time_emb = None): return h + self.res_conv(x) -class LinearAttention(nn.Module): +class LinearAttention(Module): def __init__( self, dim, @@ -231,7 +233,7 @@ def forward(self, x): out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) return self.to_out(out) -class Attention(nn.Module): +class Attention(Module): def __init__( self, dim, @@ -269,7 +271,7 @@ def forward(self, x): # model -class Unet(nn.Module): +class Unet(Module): def __init__( self, dim, @@ -278,7 +280,6 @@ def __init__( dim_mults = (1, 2, 4, 8), channels = 3, self_condition = False, - resnet_block_groups = 8, learned_variance = False, learned_sinusoidal_cond = False, random_fourier_features = False, @@ -303,8 +304,6 @@ def __init__( dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - block_klass = partial(ResnetBlock, groups = resnet_block_groups) - # time embeddings time_dim = dim * 4 @@ -341,8 +340,8 @@ def __init__( # layers - self.downs = nn.ModuleList([]) - self.ups = nn.ModuleList([]) + self.downs = ModuleList([]) + self.ups = ModuleList([]) num_resolutions = len(in_out) for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)): @@ -350,26 +349,26 @@ def __init__( attn_klass = FullAttention if layer_full_attn else LinearAttention - self.downs.append(nn.ModuleList([ - block_klass(dim_in, dim_in, time_emb_dim = time_dim), - block_klass(dim_in, dim_in, time_emb_dim = time_dim), + self.downs.append(ModuleList([ + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads), Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) ])) mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1]) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))): is_last = ind == (len(in_out) - 1) attn_klass = FullAttention if layer_full_attn else 
LinearAttention - self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + self.ups.append(ModuleList([ + ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads), Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) ])) @@ -377,7 +376,7 @@ def __init__( default_out_dim = channels * (1 if not learned_variance else 2) self.out_dim = default(out_dim, default_out_dim) - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim) self.final_conv = nn.Conv2d(dim, self.out_dim, 1) @property @@ -470,7 +469,7 @@ def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1 betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999) -class GaussianDiffusion(nn.Module): +class GaussianDiffusion(Module): def __init__( self, model, @@ -856,7 +855,7 @@ def __getitem__(self, index): # trainer class -class Trainer(object): +class Trainer: def __init__( self, diffusion_model, diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py index 03391d22f..5da033f27 100644 --- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py +++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py @@ -155,10 +155,10 @@ def forward(self, x): # building block modules class Block(nn.Module): - def __init__(self, dim, dim_out, groups = 8): + def __init__(self, dim, dim_out): super().__init__() self.proj = nn.Conv1d(dim, dim_out, 3, padding = 1) - self.norm = nn.GroupNorm(groups, dim_out) + self.norm = RMSNorm(dim_out) self.act = nn.SiLU() def forward(self, x, scale_shift = None): @@ -173,15 +173,15 @@ def forward(self, x, scale_shift = None): return x class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + def __init__(self, dim, dim_out, *, time_emb_dim = None): super().__init__() self.mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) + self.block1 = Block(dim, dim_out) + self.block2 = Block(dim_out, dim_out) self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb = None): @@ -262,7 +262,6 @@ def __init__( dim_mults=(1, 2, 4, 8), channels = 3, self_condition = False, - resnet_block_groups = 8, learned_variance = False, learned_sinusoidal_cond = False, random_fourier_features = False, @@ -285,8 +284,6 @@ def __init__( dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - block_klass = partial(ResnetBlock, groups = resnet_block_groups) - # time embeddings time_dim = dim * 4 @@ -317,23 +314,23 @@ def __init__( is_last = ind >= (num_resolutions - 1) self.downs.append(nn.ModuleList([ - block_klass(dim_in, dim_in, time_emb_dim = time_dim), - block_klass(dim_in, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), Residual(PreNorm(dim_in, LinearAttention(dim_in))), Downsample(dim_in, dim_out) if not is_last 
else nn.Conv1d(dim_in, dim_out, 3, padding = 1) ])) mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, dim_head = attn_dim_head, heads = attn_heads))) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = ind == (len(in_out) - 1) self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), Residual(PreNorm(dim_out, LinearAttention(dim_out))), Upsample(dim_out, dim_in) if not is_last else nn.Conv1d(dim_out, dim_in, 3, padding = 1) ])) @@ -341,7 +338,7 @@ def __init__( default_out_dim = channels * (1 if not learned_variance else 2) self.out_dim = default(out_dim, default_out_dim) - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim) self.final_conv = nn.Conv1d(dim, self.out_dim, 1) def forward(self, x, time, x_self_cond = None): diff --git a/denoising_diffusion_pytorch/guided_diffusion.py b/denoising_diffusion_pytorch/guided_diffusion.py index 052a148de..b41676c58 100644 --- a/denoising_diffusion_pytorch/guided_diffusion.py +++ b/denoising_diffusion_pytorch/guided_diffusion.py @@ -148,10 +148,10 @@ def forward(self, x): # building block modules class Block(nn.Module): - def __init__(self, dim, dim_out, groups = 8): + def __init__(self, dim, dim_out): super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) - self.norm = nn.GroupNorm(groups, dim_out) + self.norm = RMSNorm(dim_out) self.act = nn.SiLU() def forward(self, x, scale_shift = None): @@ -166,15 +166,15 @@ def forward(self, x, scale_shift = None): return x class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + def __init__(self, dim, dim_out, *, time_emb_dim = None): super().__init__() self.mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) + self.block1 = Block(dim, dim_out) + self.block2 = Block(dim_out, dim_out) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb = None): @@ -255,7 +255,6 @@ def __init__( dim_mults=(1, 2, 4, 8), channels = 3, self_condition = False, - resnet_block_groups = 8, learned_variance = False, learned_sinusoidal_cond = False, random_fourier_features = False, @@ -275,8 +274,6 @@ def __init__( dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - block_klass = partial(ResnetBlock, groups = resnet_block_groups) - # time embeddings time_dim = dim * 4 @@ -307,23 +304,23 @@ def __init__( is_last = ind >= (num_resolutions - 1) self.downs.append(nn.ModuleList([ - block_klass(dim_in, dim_in, time_emb_dim = time_dim), - block_klass(dim_in, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), Residual(PreNorm(dim_in, 
LinearAttention(dim_in))), Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) ])) mid_dim = dims[-1] - self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) - self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim = time_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): is_last = ind == (len(in_out) - 1) self.ups.append(nn.ModuleList([ - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), - block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim = time_dim), Residual(PreNorm(dim_out, LinearAttention(dim_out))), Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) ])) @@ -331,7 +328,7 @@ def __init__( default_out_dim = channels * (1 if not learned_variance else 2) self.out_dim = default(out_dim, default_out_dim) - self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim) self.final_conv = nn.Conv2d(dim, self.out_dim, 1) def forward(self, x, time, x_self_cond = None): diff --git a/denoising_diffusion_pytorch/simple_diffusion.py b/denoising_diffusion_pytorch/simple_diffusion.py index 1ed21574d..26bbad00e 100644 --- a/denoising_diffusion_pytorch/simple_diffusion.py +++ b/denoising_diffusion_pytorch/simple_diffusion.py @@ -116,10 +116,10 @@ def forward(self, x): # building block modules class Block(nn.Module): - def __init__(self, dim, dim_out, groups = 8): + def __init__(self, dim, dim_out): super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) - self.norm = nn.GroupNorm(groups, dim_out) + self.norm = RMSNorm(dim_out, normalize_dim = 1) self.act = nn.SiLU() def forward(self, x, scale_shift = None): @@ -134,15 +134,15 @@ def forward(self, x, scale_shift = None): return x class ResnetBlock(nn.Module): - def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + def __init__(self, dim, dim_out, *, time_emb_dim = None): super().__init__() self.mlp = nn.Sequential( nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2) ) if exists(time_emb_dim) else None - self.block1 = Block(dim, dim_out, groups = groups) - self.block2 = Block(dim_out, dim_out, groups = groups) + self.block1 = Block(dim, dim_out) + self.block2 = Block(dim_out, dim_out) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb = None): @@ -319,7 +319,6 @@ def __init__( attn_dim_head = 32, attn_heads = 4, ff_mult = 4, - resnet_block_groups = 8, learned_sinusoidal_dim = 16, init_img_transform: callable = None, final_img_itransform: callable = None, @@ -369,8 +368,6 @@ def __init__( dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) - resnet_block = partial(ResnetBlock, groups = resnet_block_groups) - # time embeddings time_dim = dim * 4 @@ -400,8 +397,8 @@ def __init__( is_last = ind >= (num_resolutions - 1) self.downs.append(nn.ModuleList([ - resnet_block(dim_in, dim_in, time_emb_dim = time_dim), - resnet_block(dim_in, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in, dim_in, time_emb_dim = time_dim), 
LinearAttention(dim_in), Downsample(dim_in, dim_out, factor = factor) ])) @@ -423,15 +420,15 @@ def __init__( self.ups.append(nn.ModuleList([ Upsample(dim_out, dim_in, factor = factor), - resnet_block(dim_in * 2, dim_in, time_emb_dim = time_dim), - resnet_block(dim_in * 2, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in * 2, dim_in, time_emb_dim = time_dim), + ResnetBlock(dim_in * 2, dim_in, time_emb_dim = time_dim), LinearAttention(dim_in), ])) default_out_dim = input_channels self.out_dim = default(out_dim, default_out_dim) - self.final_res_block = resnet_block(dim * 2, dim, time_emb_dim = time_dim) + self.final_res_block = ResnetBlock(dim * 2, dim, time_emb_dim = time_dim) self.final_conv = nn.Conv2d(dim, self.out_dim, 1) def forward(self, x, time): diff --git a/denoising_diffusion_pytorch/version.py b/denoising_diffusion_pytorch/version.py index 522ba08c7..afced1472 100644 --- a/denoising_diffusion_pytorch/version.py +++ b/denoising_diffusion_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.11.1' +__version__ = '2.0.0'
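
Note: the snippet below is an illustrative usage sketch, not part of the patch. It assumes the public API shown in the README above is otherwise unchanged; its only point is that, after this change, `Unet` no longer accepts a `resnet_block_groups` argument, since the blocks now normalize with `RMSNorm` (which has no group count) instead of `nn.GroupNorm`.

```python
# Illustrative sketch only (not part of this patch); assumes the package's
# public API otherwise matches the README example above.
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
    # resnet_block_groups = 8  # no longer accepted: Block uses RMSNorm, which takes only the channel dim
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000
)

training_images = torch.rand(8, 3, 128, 128)  # images normalized from 0 to 1
loss = diffusion(training_images)
loss.backward()
```

Callers that previously passed `resnet_block_groups` should simply drop the argument, hence the major version bump to 2.0.0.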