Commit 9eb3a11

ok

johndpope committed Jun 4, 2024
1 parent 7bcf6af commit 9eb3a11

Showing 3 changed files with 81 additions and 61 deletions.
4 changes: 2 additions & 2 deletions configs/training/stage1-base.yaml

@@ -1,6 +1,6 @@
 data:
-  train_width: 256
-  train_height: 256
+  train_width: 512
+  train_height: 512
   sample_rate: 25
   n_sample_frames: 1
   n_motion_frames: 2
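
The jump from 256 to 512 also changes the PatchGAN geometry in train.py below, where the patch size is derived from these values. A quick check of the arithmetic (a sketch; only the two config values above are assumed):

# Four stride-2 convolutions in the discriminator each halve the spatial size, hence 2 ** 4.
train_width, train_height = 512, 512
patch = (1, train_width // 2 ** 4, train_height // 2 ** 4)
print(patch)  # (1, 32, 32); at the old 256 x 256 resolution this was (1, 16, 16)
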
52 changes: 39 additions & 13 deletions model.py

@@ -1800,25 +1800,51 @@ def forward(self, predicted_gaze, target_gaze, face_image):
 
         return loss / len(eye_landmarks)
 
-class Discriminator(nn.Module):
-    def __init__(self, input_nc, ndf=64, n_layers=3):
-        super(Discriminator, self).__init__()
+# class Discriminator(nn.Module):
+#     def __init__(self, input_nc, ndf=64, n_layers=3):
+#         super(Discriminator, self).__init__()
 
-        layers = [nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
-                  nn.LeakyReLU(0.2, True)]
+#         layers = [nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
+#                   nn.LeakyReLU(0.2, True)]
 
-        for i in range(1, n_layers):
-            layers += [nn.Conv2d(ndf * 2**(i-1), ndf * 2**i, kernel_size=4, stride=2, padding=1),
-                       nn.InstanceNorm2d(ndf * 2**i),
-                       nn.LeakyReLU(0.2, True)]
+#         for i in range(1, n_layers):
+#             layers += [nn.Conv2d(ndf * 2**(i-1), ndf * 2**i, kernel_size=4, stride=2, padding=1),
+#                        nn.InstanceNorm2d(ndf * 2**i),
+#                        nn.LeakyReLU(0.2, True)]
 
-        layers += [nn.Conv2d(ndf * 2**(n_layers-1), 1, kernel_size=4, stride=1, padding=1)]
+#         layers += [nn.Conv2d(ndf * 2**(n_layers-1), 1, kernel_size=4, stride=1, padding=1)]
 
-        self.model = nn.Sequential(*layers)
+#         self.model = nn.Sequential(*layers)
 
-    def forward(self, x):
-        return self.model(x)
+#     def forward(self, x):
+#         return self.model(x)
+
+
+class Discriminator(nn.Module):
+    def __init__(self, in_channels=3):
+        super(Discriminator, self).__init__()
+
+        def discriminator_block(in_filters, out_filters, normalization=True):
+            """Returns downsampling layers of each discriminator block"""
+            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
+            if normalization:
+                layers.append(nn.InstanceNorm2d(out_filters))
+            layers.append(nn.LeakyReLU(0.2, inplace=True))
+            return layers
+
+        self.model = nn.Sequential(
+            *discriminator_block(in_channels * 2, 64, normalization=False),
+            *discriminator_block(64, 128),
+            *discriminator_block(128, 256),
+            *discriminator_block(256, 512),
+            nn.ZeroPad2d((1, 0, 1, 0)),
+            nn.Conv2d(512, 1, 4, padding=1, bias=False)
+        )
+
+    def forward(self, img_A, img_B):
+        # Concatenate image and condition image by channels to produce input
+        img_input = torch.cat((img_A, img_B), 1)
+        return self.model(img_input)
 
 class PerceptualLoss(nn.Module):
     def __init__(self, device, weights={'vgg19': 20.0, 'vggface':5.0, 'gaze': 4.0}):
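
A minimal shape check for the new pix2pix-style PatchGAN discriminator above (a sketch; it assumes only torch and the class as committed). The two images are concatenated on the channel axis, so in_channels * 2 = 6 channels go in, and one logit per spatial patch comes out:

import torch

D = Discriminator(in_channels=3)
img_A = torch.randn(1, 3, 512, 512)  # e.g. a driving frame
img_B = torch.randn(1, 3, 512, 512)  # e.g. the source frame used as condition
out = D(img_A, img_B)
print(out.shape)  # torch.Size([1, 1, 32, 32]), matching the patch tuple in train.py
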
86 changes: 40 additions & 46 deletions train.py

@@ -20,7 +20,7 @@
 import torchvision.utils as vutils
 import time
 from torch.cuda.amp import autocast, GradScaler
-
+from torch.autograd import Variable
 
 output_dir = "output_images"
 os.makedirs(output_dir, exist_ok=True)
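
Worth noting: torch.autograd.Variable has been a deprecated no-op wrapper since PyTorch 0.4, so the label tensors built with it further down could be plain tensors. A hypothetical equivalent (batch_size and patch stand in for the values used in the training loop):

# Same result as Variable(torch.Tensor(np.ones(...)), requires_grad=False).to(device)
valid = torch.ones((batch_size, *patch), device=device)    # +1 targets for real pairs
fake = -torch.ones((batch_size, *patch), device=device)    # -1 targets for fake pairs
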
@@ -31,14 +31,6 @@
 device = torch.device("cuda" if use_cuda else "cpu")
 
 
-# In the adversarial_loss function, we now use the hinge loss for the generator.
-# The loss is calculated as the negative mean of the discriminator's prediction
-# for the fake frame. This encourages the generator to produce frames that can fool
-# the discriminator.
-def adversarial_loss(output_frame, discriminator):
-    fake_pred = discriminator(output_frame)
-    loss = -torch.mean(fake_pred)
-    return loss.requires_grad_()
 
 
 # align to cyclegan
@@ -52,15 +44,7 @@ def discriminator_loss(real_pred, fake_pred, loss_type='lsgan'):
     else:
         raise NotImplementedError(f'Loss type {loss_type} is not implemented.')
 
-    return (real_loss + fake_loss) * 0.5
-
-
-def feature_matching_loss(real_features, fake_features):
-    loss = 0
-    for real_feat, fake_feat in zip(real_features, fake_features):
-        loss += torch.mean(torch.abs(real_feat - fake_feat))
-    return loss.requires_grad_()
+    return ((real_loss + fake_loss) * 0.5).requires_grad_()
 
 
 # cosine distance formula
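
The body of discriminator_loss sits mostly outside this hunk; only the fallthrough branch is visible. For reference, a least-squares ('lsgan') discriminator objective in the spirit of the CycleGAN code this aligns to would look like the following sketch (not the file's exact text):

import torch
import torch.nn.functional as F

def lsgan_discriminator_loss(real_pred, fake_pred):
    # Real patches should score 1, fake patches 0; penalize squared error to those targets.
    real_loss = F.mse_loss(real_pred, torch.ones_like(real_pred))
    fake_loss = F.mse_loss(fake_pred, torch.zeros_like(fake_pred))
    return (real_loss + fake_loss) * 0.5
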
@@ -89,10 +73,12 @@ def cosine_loss(pos_pairs, neg_pairs, s=5.0, m=0.2):
         loss = loss + torch.log(torch.exp(pos_dist) / (torch.exp(pos_dist) + neg_term))
 
     assert len(pos_pairs) > 0, "pos_pairs should not be empty"
-    return (-loss / len(pos_pairs)).requires_grad_()
-
+    return torch.mean(-loss / len(pos_pairs)).requires_grad_()
 
 def train_base(cfg, Gbase, Dbase, dataloader):
+    patch = (1, cfg.data.train_width // 2 ** 4, cfg.data.train_height // 2 ** 4)
+    hinge_loss = nn.HingeEmbeddingLoss(reduction='mean')
+    feature_matching_loss = nn.MSELoss()
     Gbase.train()
     Dbase.train()
     optimizer_G = torch.optim.AdamW(Gbase.parameters(), lr=cfg.training.lr, betas=(0.5, 0.999), weight_decay=1e-2)
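
Only the tail of cosine_loss appears in the hunk. The overall shape of the objective, a CosFace-style contrastive loss over cosine similarities with scale s and margin m, is sketched below under the assumption that each pair is a tuple of embedding vectors:

import torch
import torch.nn.functional as F

def cosine_loss_sketch(pos_pairs, neg_pairs, s=5.0, m=0.2):
    loss = torch.tensor(0.0)
    for z_i, z_j in pos_pairs:
        # Scaled, margin-penalized similarity for the positive pair.
        pos_dist = s * (F.cosine_similarity(z_i, z_j, dim=-1) - m)
        # Every negative competes with the positive in the softmax denominator.
        neg_term = sum(torch.exp(s * F.cosine_similarity(z_i, z_k, dim=-1))
                       for _, z_k in neg_pairs)
        loss = loss + torch.log(torch.exp(pos_dist) / (torch.exp(pos_dist) + neg_term))
    assert len(pos_pairs) > 0, "pos_pairs should not be empty"
    return torch.mean(-loss / len(pos_pairs))
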
@@ -166,12 +152,37 @@ def train_base(cfg, Gbase, Dbase, dataloader):
 
             # Calculate perceptual losses
             loss_G_per = perceptual_loss_fn(pred_frame, source_frame)
 
+            # Adversarial ground truths - from Kevin Fringe
+            valid = Variable(torch.Tensor(np.ones((driving_frame.size(0), *patch))), requires_grad=False).to(device)
+            fake = Variable(torch.Tensor(-1 * np.ones((driving_frame.size(0), *patch))), requires_grad=False).to(device)
+
+            # real loss
+            real_pred = Dbase(driving_frame, source_frame)
+            loss_real = hinge_loss(real_pred, valid)
+
+            # fake loss
+            fake_pred = Dbase(pred_frame.detach(), source_frame)
+            loss_fake = hinge_loss(fake_pred, fake)
+
+            # Train discriminator
+            optimizer_D.zero_grad()
+
-            # Calculate adversarial losses
-            loss_G_adv = adversarial_loss(pred_frame, Dbase)
-            loss_fm = perceptual_loss_fn(pred_frame, source_frame, use_fm_loss=True)
+            real_pred = Dbase(driving_frame, source_frame)
+            fake_pred = Dbase(pred_frame.detach(), source_frame)
+            loss_D = discriminator_loss(real_pred, fake_pred, loss_type='lsgan')
+
+            scaler.scale(loss_D).backward()
+            scaler.step(optimizer_D)
+            scaler.update()
+
+            # Calculate adversarial losses
+            loss_G_adv = 0.5 * (loss_real + loss_fake)
+
+            # Feature matching loss
+            loss_fm = feature_matching_loss(pred_frame, driving_frame)
 
 
             # The other objective CycleGAN regularizes the training and introduces disentanglement between the motion and canonical space
             # In order to calculate this loss, we use an additional source-driving pair x𝑠∗ and x𝑑∗,
             # which is sampled from a different video! and therefore has different appearance from the current x𝑠, x𝑑 pair.
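
For reference on what the hinge terms above evaluate to: nn.HingeEmbeddingLoss with its default margin of 1 returns the input unchanged where the target is +1 and max(0, 1 - input) where the target is -1. A small check, assuming the +/-1 label tensors defined earlier:

import torch
import torch.nn as nn

hinge = nn.HingeEmbeddingLoss(reduction='mean')
pred = torch.tensor([0.3, -0.7])
print(hinge(pred, torch.ones(2)))   # targets +1: mean(pred) = -0.2000
print(hinge(pred, -torch.ones(2)))  # targets -1: mean(max(0, 1 - pred)) = 1.2000
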
@@ -201,34 +212,17 @@ def train_base(cfg, Gbase, Dbase, dataloader):
                 loss_G_cos = cosine_loss(P, N)
 
 
-            # Combine the losses
-
+            # Backpropagate and update generator
             optimizer_G.zero_grad()
             total_loss = cfg.training.w_per * loss_G_per + \
                          cfg.training.w_adv * loss_G_adv + \
                          cfg.training.w_fm * loss_fm + \
-                         cfg.training.w_cos * loss_G_cos
-
-            # Convert total_loss to a scalar value
-            total_loss = torch.mean(total_loss)
-
-
-            # Backpropagate and update generator
+                         cfg.training.w_cos * loss_G_cos
             scaler.scale(total_loss).backward()
             scaler.step(optimizer_G)
             scaler.update()
-
-
-            # Train discriminator
-            optimizer_D.zero_grad()
-
-            with autocast():
-                # Calculate adversarial losses
-                real_pred = Dbase(driving_frame)
-                fake_pred = Dbase(pred_frame.detach())
-                loss_D = discriminator_loss(real_pred, fake_pred, loss_type='lsgan')
-
-            scaler.scale(loss_D).backward()
-            scaler.step(optimizer_D)
-            scaler.update()
 
 
         scheduler_G.step()
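
The discriminator and generator updates now share one GradScaler, each with its own scale/backward/step/update sequence. The pattern PyTorch documents for several optimizers in a single iteration defers update() until every optimizer has stepped; a hypothetical consolidation of the two updates above:

optimizer_D.zero_grad()
scaler.scale(loss_D).backward()
scaler.step(optimizer_D)

optimizer_G.zero_grad()
scaler.scale(total_loss).backward()
scaler.step(optimizer_G)

scaler.update()  # called once per iteration, after all optimizers have stepped
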
@@ -303,7 +297,7 @@ def main(cfg: OmegaConf) -> None:
 
 
     Gbase = model.Gbase().to(device)
-    Dbase = model.Discriminator(input_nc=3).to(device)
+    Dbase = model.Discriminator().to(device)
 
     train_base(cfg, Gbase, Dbase, dataloader)
    torch.save(Gbase.state_dict(), 'Gbase.pth')
