From 23811d9879559c1ae03560be101ed2952bac0248 Mon Sep 17 00:00:00 2001
From: 1Konny <lee.wonkwang94@gmail.com>
Date: Tue, 22 May 2018 18:34:55 +0900
Subject: [PATCH] add new version

---
 main.py            |  27 ++--
 model.py           | 173 ++++++++++++++--
 run_3dchairs.sh    |   3 -
 run_3dchairs_H.sh  |   5 +
 run_celeba.sh      |   3 -
 run_celeba_H.sh    |   5 +
 run_dsprites.sh    |   3 -
 run_dsprites_B.sh  |   5 +
 run_dsprites_B2.sh |   5 +
 solver.py          | 358 +++++++++++++++++++++++++++++++++------------
 10 files changed, 411 insertions(+), 176 deletions(-)
 delete mode 100644 run_3dchairs.sh
 create mode 100644 run_3dchairs_H.sh
 delete mode 100644 run_celeba.sh
 create mode 100644 run_celeba_H.sh
 delete mode 100644 run_dsprites.sh
 create mode 100644 run_dsprites_B.sh
 create mode 100644 run_dsprites_B2.sh

diff --git a/main.py b/main.py
index c2a626f..85f230c 100644
--- a/main.py
+++ b/main.py
@@ -11,16 +11,13 @@ from utils import str2bool
 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True

-init_seed = 1
-torch.manual_seed(init_seed)
-torch.cuda.manual_seed(init_seed)
-np.random.seed(init_seed)
-
-np.set_printoptions(precision=4)
-torch.set_printoptions(precision=4)
-

 def main(args):
+    seed = args.seed
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+
     net = Solver(args)

     if args.train:
@@ -33,15 +30,21 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='toy Beta-VAE')

     parser.add_argument('--train', default=True, type=str2bool, help='train or traverse')
+    parser.add_argument('--seed', default=1, type=int, help='random seed')
     parser.add_argument('--cuda', default=True, type=str2bool, help='enable cuda')
     parser.add_argument('--max_iter', default=1e6, type=float, help='maximum training iteration')
     parser.add_argument('--batch_size', default=64, type=int, help='batch size')

     parser.add_argument('--z_dim', default=10, type=int, help='dimension of the representation z')
-    parser.add_argument('--beta', default=4, type=float, help='beta parameter for KL-term')
+    parser.add_argument('--beta', default=4, type=float, help='beta parameter for KL-term in original beta-VAE')
+    parser.add_argument('--objective', default='H', type=str, help='beta-VAE objective proposed in Higgins et al. (H) or Burgess et al. (B)')
+    parser.add_argument('--model', default='H', type=str, help='model proposed in Higgins et al. (H) or Burgess et al. (B)')
+    parser.add_argument('--gamma', default=1000, type=float, help='gamma parameter for KL-term in understanding beta-VAE')
+    parser.add_argument('--C_max', default=25, type=float, help='capacity parameter (C) of bottleneck channel')
+    parser.add_argument('--C_stop_iter', default=1e5, type=float, help='when to stop increasing the capacity')
     parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
-    parser.add_argument('--beta1', default=0.5, type=float, help='Adam optimizer beta1')
-    parser.add_argument('--beta2', default=0.9, type=float, help='Adam optimizer beta2')
+    parser.add_argument('--beta1', default=0.9, type=float, help='Adam optimizer beta1')
+    parser.add_argument('--beta2', default=0.999, type=float, help='Adam optimizer beta2')

     parser.add_argument('--dset_dir', default='data', type=str, help='dataset directory')
     parser.add_argument('--dataset', default='CelebA', type=str, help='dataset name')
@@ -53,7 +56,7 @@
     parser.add_argument('--viz_port', default=8097, type=str, help='visdom port number')

     parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory')
-    parser.add_argument('--load_ckpt', default=True, type=str2bool, help='load last checkpoint')
+    parser.add_argument('--ckpt_name', default=None, type=str, help='name of a previous checkpoint to load')

     args = parser.parse_args()
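Note: main() now seeds torch, torch.cuda, and numpy from --seed, but cudnn.benchmark stays enabled above, so two runs with the same seed can still diverge on GPU. A stricter variant, in case bit-exact reproducibility is ever needed (a sketch, not part of this patch; it trades speed for determinism):

    import numpy as np
    import torch

    def set_seed(seed):
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)        # all GPUs, not just the current one
        np.random.seed(seed)
        torch.backends.cudnn.benchmark = False  # disable non-deterministic autotuning
        torch.backends.cudnn.deterministic = True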
diff --git a/model.py b/model.py
index fd973ea..61f4791 100644
--- a/model.py
+++ b/model.py
@@ -7,104 +7,141 @@ import torch.nn.init as init
 from torch.autograd import Variable


-class BetaVAE_3D(nn.Module):
-    def __init__(self, z_dim=10):
-        super(BetaVAE_3D, self).__init__()
+def reparametrize(mu, logvar):
+    std = logvar.div(2).exp()
+    eps = Variable(std.data.new(std.size()).normal_())
+    return mu + std*eps
+
+
+class View(nn.Module):
+    def __init__(self, size):
+        super(View, self).__init__()
+        self.size = size
+
+    def forward(self, tensor):
+        return tensor.view(self.size)
+
+
+class BetaVAE_H(nn.Module):
+    """Model proposed in the original beta-VAE paper (Higgins et al, ICLR, 2017)."""
+
+    def __init__(self, z_dim=10, nc=3):
+        super(BetaVAE_H, self).__init__()
         self.z_dim = z_dim
-        self.encode = nn.Sequential(
-            nn.Conv2d(3, 32, 4, 2, 1),
+        self.nc = nc
+        self.encoder = nn.Sequential(
+            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
             nn.ReLU(True),
-            nn.Conv2d(32, 32, 4, 2, 1),
+            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
             nn.ReLU(True),
-            nn.Conv2d(32, 64, 4, 2, 1),
+            nn.Conv2d(32, 64, 4, 2, 1),          # B,  64,  8,  8
             nn.ReLU(True),
-            nn.Conv2d(64, 64, 4, 2, 1),
+            nn.Conv2d(64, 64, 4, 2, 1),          # B,  64,  4,  4
             nn.ReLU(True),
-            nn.Conv2d(64, 256, 4, 1),
+            nn.Conv2d(64, 256, 4, 1),            # B, 256,  1,  1
             nn.ReLU(True),
-            nn.Conv2d(256, 2*self.z_dim, 1)
+            View((-1, 256*1*1)),                 # B, 256
+            nn.Linear(256, z_dim*2),             # B, z_dim*2
         )
-        self.decode = nn.Sequential(
-            nn.Conv2d(self.z_dim, 256, 1),
+        self.decoder = nn.Sequential(
+            nn.Linear(z_dim, 256),               # B, 256
+            View((-1, 256, 1, 1)),               # B, 256,  1,  1
             nn.ReLU(True),
-            nn.ConvTranspose2d(256, 64, 4),
+            nn.ConvTranspose2d(256, 64, 4),      # B,  64,  4,  4
             nn.ReLU(True),
-            nn.ConvTranspose2d(64, 64, 4, 2, 1),
+            nn.ConvTranspose2d(64, 64, 4, 2, 1), # B,  64,  8,  8
             nn.ReLU(True),
-            nn.ConvTranspose2d(64, 32, 4, 2, 1),
+            nn.ConvTranspose2d(64, 32, 4, 2, 1), # B,  32, 16, 16
             nn.ReLU(True),
-            nn.ConvTranspose2d(32, 32, 4, 2, 1),
+            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
             nn.ReLU(True),
-            nn.ConvTranspose2d(32, 3, 4, 2, 1),
+            nn.ConvTranspose2d(32, nc, 4, 2, 1), # B,  nc, 64, 64
         )
-        self.weight_init()
+
+        self.weight_init()

     def weight_init(self):
         for block in self._modules:
             for m in self._modules[block]:
                 kaiming_init(m)

-    def reparametrize(self, mu, logvar):
-        std = logvar.mul(0.5).exp_()
-        eps = Variable(std.data.new(std.size()).normal_())
-        return eps.mul(std).add_(mu)
-
     def forward(self, x):
-        stats = self.encode(x)
-        mu = stats[:, :self.z_dim]
-        logvar = stats[:, self.z_dim:]
-        z = self.reparametrize(mu, logvar)
-        x_recon = self.decode(z)
+        distributions = self._encode(x)
+        mu = distributions[:, :self.z_dim]
+        logvar = distributions[:, self.z_dim:]
+        z = reparametrize(mu, logvar)
+        x_recon = self._decode(z)

-        return x_recon, mu.squeeze(), logvar.squeeze()
+        return x_recon, mu, logvar

+    def _encode(self, x):
+        return self.encoder(x)

-class View(nn.Module):
-    def __init__(self, size):
-        super(View, self).__init__()
-        self.size = size
+    def _decode(self, z):
+        return self.decoder(z)

-    def forward(self, tensor):
-        return tensor.view(self.size)

+class BetaVAE_B(BetaVAE_H):
+    """Model proposed in the understanding beta-VAE paper (Burgess et al, arxiv:1804.03599, 2018)."""

-class BetaVAE_2D(BetaVAE_3D):
-    def __init__(self, z_dim=10):
-        super(BetaVAE_2D, self).__init__()
+    def __init__(self, z_dim=10, nc=1):
+        super(BetaVAE_B, self).__init__()
+        self.nc = nc
         self.z_dim = z_dim
-        # Views are applied just for the consistency in shape with CONV-based models
-        self.encode = nn.Sequential(
-            View((-1, 4096)),
-            nn.Linear(4096, 1200),
+        self.encoder = nn.Sequential(
+            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
+            nn.ReLU(True),
+            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
             nn.ReLU(True),
-            nn.Linear(1200, 1200),
+            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32,  8,  8
             nn.ReLU(True),
-            nn.Linear(1200, 2*self.z_dim),
-            View((-1, 2*self.z_dim, 1, 1)),
+            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32,  4,  4
+            nn.ReLU(True),
+            View((-1, 32*4*4)),                  # B, 512
+            nn.Linear(32*4*4, 256),              # B, 256
+            nn.ReLU(True),
+            nn.Linear(256, 256),                 # B, 256
+            nn.ReLU(True),
+            nn.Linear(256, z_dim*2),             # B, z_dim*2
         )
-        self.decode = nn.Sequential(
-            View((-1, self.z_dim)),
-            nn.Linear(self.z_dim, 1200),
-            nn.Tanh(),
-            nn.Linear(1200, 1200),
-            nn.Tanh(),
-            nn.Linear(1200, 1200),
-            nn.Tanh(),
-            nn.Linear(1200, 4096),
-            View((-1, 1, 64, 64)),
+
+        self.decoder = nn.Sequential(
+            nn.Linear(z_dim, 256),               # B, 256
+            nn.ReLU(True),
+            nn.Linear(256, 256),                 # B, 256
+            nn.ReLU(True),
+            nn.Linear(256, 32*4*4),              # B, 512
+            nn.ReLU(True),
+            View((-1, 32, 4, 4)),                # B,  32,  4,  4
+            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32,  8,  8
+            nn.ReLU(True),
+            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 16, 16
+            nn.ReLU(True),
+            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
+            nn.ReLU(True),
+            nn.ConvTranspose2d(32, nc, 4, 2, 1), # B,  nc, 64, 64
         )

         self.weight_init()

+    def weight_init(self):
+        for block in self._modules:
+            for m in self._modules[block]:
+                kaiming_init(m)
+
     def forward(self, x):
-        stats = self.encode(x)
-        mu = stats[:, :self.z_dim]
-        logvar = stats[:, self.z_dim:]
-        z = self.reparametrize(mu, logvar)
-        x_recon = self.decode(z).view(x.size())
+        distributions = self._encode(x)
+        mu = distributions[:, :self.z_dim]
+        logvar = distributions[:, self.z_dim:]
+        z = reparametrize(mu, logvar)
+        x_recon = self._decode(z).view(x.size())

         return x_recon, mu, logvar

+    def _encode(self, x):
+        return self.encoder(x)
+
+    def _decode(self, z):
+        return self.decoder(z)
+

 def kaiming_init(m):
     if isinstance(m, (nn.Linear, nn.Conv2d)):
@@ -117,8 +154,16 @@ def kaiming_init(m):
         m.bias.data.fill_(0)


+def normal_init(m, mean, std):
+    if isinstance(m, (nn.Linear, nn.Conv2d)):
+        m.weight.data.normal_(mean, std)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
+        m.weight.data.fill_(1)
+        if m.bias is not None:
+            m.bias.data.zero_()
+
+
 if __name__ == '__main__':
-    import ipdb; ipdb.set_trace()
-    net = BetaVAE(32)
-    x = Variable(torch.rand(1, 3, 64, 64))
-    net(x)
+    pass
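Note: reparametrize() above implements the reparameterization trick, z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through mu and logvar; the Variable wrapper is the pre-0.4 PyTorch API. For reference, the same function in current PyTorch would be roughly (a sketch, not part of this patch):

    import torch

    def reparametrize(mu, logvar):
        # sigma = exp(logvar / 2); sampling stays differentiable w.r.t. mu and logvar
        std = (logvar / 2).exp()
        eps = torch.randn_like(std)
        return mu + std * eps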
diff --git a/run_3dchairs.sh b/run_3dchairs.sh
deleted file mode 100644
index bf3e82d..0000000
--- a/run_3dchairs.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#! /bin/sh
-
-python main.py --dataset 3dchairs --max_iter 1e6 --beta 4 --batch_size 64 --lr 1e-4 --z_dim 10 --viz_name beta_vae_3dchairs --viz_port 55558
diff --git a/run_3dchairs_H.sh b/run_3dchairs_H.sh
new file mode 100644
index 0000000..81ed6e8
--- /dev/null
+++ b/run_3dchairs_H.sh
@@ -0,0 +1,5 @@
+#! /bin/sh
+
+python main.py --dataset 3dchairs --seed 1 --lr 1e-4 --beta1 0.9 --beta2 0.999 \
+               --objective H --model H --batch_size 64 --z_dim 10 --max_iter 1e6 \
+               --beta 4 --viz_name 3dchairs_H
diff --git a/run_celeba.sh b/run_celeba.sh
deleted file mode 100644
index 8a221e8..0000000
--- a/run_celeba.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#! /bin/sh
-
-python main.py --dataset celeba --max_iter 1e6 --beta 64 --batch_size 64 --lr 1e-4 --z_dim 10 --viz_name beta_vae_celeba --viz_port 55558
diff --git a/run_celeba_H.sh b/run_celeba_H.sh
new file mode 100644
index 0000000..1c2dee8
--- /dev/null
+++ b/run_celeba_H.sh
@@ -0,0 +1,5 @@
+#! /bin/sh
+
+python main.py --dataset celeba --seed 1 --lr 1e-4 --beta1 0.9 --beta2 0.999 \
+               --objective H --model H --batch_size 64 --z_dim 10 --max_iter 1e6 \
+               --beta 64 --viz_name celeba_H
diff --git a/run_dsprites.sh b/run_dsprites.sh
deleted file mode 100644
index 0658184..0000000
--- a/run_dsprites.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#! /bin/sh
-
-python main.py --dataset dsprites --max_iter 3e5 --beta 4 --batch_size 64 --lr 1e-4 --z_dim 10 --viz_name beta_vae_dsprites --viz_port 55558
diff --git a/run_dsprites_B.sh b/run_dsprites_B.sh
new file mode 100644
index 0000000..346f5c2
--- /dev/null
+++ b/run_dsprites_B.sh
@@ -0,0 +1,5 @@
+#! /bin/sh
+
+python main.py --dataset dsprites --seed 2 --lr 5e-4 --beta1 0.9 --beta2 0.999 \
+               --objective B --model B --batch_size 64 --z_dim 10 --max_iter 1e6 \
+               --C_stop_iter 1e5 --C_max 20 --gamma 100 --viz_name dsprites_B
diff --git a/run_dsprites_B2.sh b/run_dsprites_B2.sh
new file mode 100644
index 0000000..f373406
--- /dev/null
+++ b/run_dsprites_B2.sh
@@ -0,0 +1,5 @@
+#! /bin/sh
+
+python main.py --dataset dsprites --seed 2 --lr 5e-4 --beta1 0.9 --beta2 0.999 \
+               --objective B --model B --batch_size 64 --z_dim 10 --max_iter 1e6 \
+               --C_stop_iter 1e5 --C_max 20 --gamma 100 --viz_name dsprites_B --ckpt_name last --viz_port 55558
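Note: in the two B scripts, --C_max 20 and --C_stop_iter 1e5 define the linear capacity ramp used by the Burgess et al. objective: C grows from 0 to C_max over the first C_stop_iter iterations, then stays flat. Stripped of the Variable/clamp plumbing in solver.py, the schedule reduces to this sketch (function name is illustrative):

    def capacity_at(global_iter, C_max=20.0, C_stop_iter=1e5):
        # linear ramp: 0 -> C_max over C_stop_iter steps, constant afterwards
        return min(C_max, C_max / C_stop_iter * global_iter)

    # e.g. capacity_at(50000) == 10.0, capacity_at(200000) == 20.0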
diff --git a/solver.py b/solver.py
index 15c6dfd..27a8779 100644
--- a/solver.py
+++ b/solver.py
@@ -1,9 +1,9 @@
 """solver.py"""

-import time
 from pathlib import Path
-
+from tqdm import tqdm
 import visdom
+
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
@@ -11,194 +11,364 @@ from torch.autograd import Variable
 from torchvision.utils import make_grid

 from utils import cuda
-from model import BetaVAE_2D, BetaVAE_3D
+from model import BetaVAE_H, BetaVAE_B
 from dataset import return_data


-def original_vae_loss(x, x_recon, mu, logvar):
+def reconstruction_loss(x, x_recon, distribution):
     batch_size = x.size(0)
-    if batch_size == 0:
-        recon_loss = 0
-        kl_divergence = 0
-    else:
+    assert batch_size != 0
+
+    if distribution == 'bernoulli':
         recon_loss = F.binary_cross_entropy_with_logits(x_recon, x, size_average=False).div(batch_size)
-        # kld which one is correct? from the equation in most of papers,
-        # I think the first one is correct but official pytorch code uses the second one.
+    elif distribution == 'gaussian':
+        x_recon = F.sigmoid(x_recon)
+        recon_loss = F.mse_loss(x_recon, x, size_average=False).div(batch_size)
+    else:
+        recon_loss = None
+
+    return recon_loss
+
+
+def kl_divergence(mu, logvar):
+    batch_size = mu.size(0)
+    assert batch_size != 0
+    if mu.data.ndimension() == 4:
+        mu = mu.view(mu.size(0), mu.size(1))
+    if logvar.data.ndimension() == 4:
+        logvar = logvar.view(logvar.size(0), logvar.size(1))
+
+    klds = -0.5*(1 + logvar - mu.pow(2) - logvar.exp())
+    total_kld = klds.sum(1).mean()
+    dimension_wise_kld = klds.mean(0)
+    mean_kld = klds.mean()
+
+    return total_kld, dimension_wise_kld, mean_kld
+
+
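Note: kl_divergence() above evaluates the closed form of KL(N(mu, sigma^2) || N(0, I)) per latent dimension, -0.5 * (1 + log sigma^2 - mu^2 - sigma^2), and reports it summed over dimensions (total), per dimension, and averaged. A quick standalone sanity check against torch.distributions (assumes a recent PyTorch; not part of this patch):

    import torch
    from torch.distributions import Normal, kl_divergence as kl

    mu, logvar = torch.randn(8, 10), torch.randn(8, 10)
    closed_form = -0.5*(1 + logvar - mu.pow(2) - logvar.exp())
    reference = kl(Normal(mu, (logvar/2).exp()), Normal(0.0, 1.0))
    assert torch.allclose(closed_form, reference, atol=1e-6)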
+class DataGather(object):
+    def __init__(self):
+        self.data = self.get_empty_data_dict()
+
+    def get_empty_data_dict(self):
+        return dict(iter=[],
+                    recon_loss=[],
+                    total_kld=[],
+                    dim_wise_kld=[],
+                    mean_kld=[],
+                    mu=[],
+                    var=[],
+                    images=[],)

-        kl_divergence = -0.5*(1 + logvar - mu**2 - logvar.exp()).sum(1).mean()
-        #kl_divergence = -0.5*(1 + logvar - mu**2 - logvar.exp()).sum()
-        #dimension_wise_kl_divergence = -0.5*(1 + logvar - mu**2 - logvar.exp()).mean(0)
+    def insert(self, **kwargs):
+        for key in kwargs:
+            self.data[key].append(kwargs[key])

-    return recon_loss, kl_divergence
+    def flush(self):
+        self.data = self.get_empty_data_dict()


 class Solver(object):
     def __init__(self, args):
-
-        # Misc
         self.use_cuda = args.cuda and torch.cuda.is_available()
         self.max_iter = args.max_iter
         self.global_iter = 0

-        # Networks & Optimizers
         self.z_dim = args.z_dim
         self.beta = args.beta
-
+        self.gamma = args.gamma
+        self.C_max = args.C_max
+        self.C_stop_iter = args.C_stop_iter
+        self.objective = args.objective
+        self.model = args.model
         self.lr = args.lr
         self.beta1 = args.beta1
         self.beta2 = args.beta2

         if args.dataset.lower() == 'dsprites':
-            net = BetaVAE_2D
+            self.nc = 1
+            self.decoder_dist = 'bernoulli'
         elif args.dataset.lower() == '3dchairs':
-            net = BetaVAE_3D
+            self.nc = 3
+            self.decoder_dist = 'bernoulli'
         elif args.dataset.lower() == 'celeba':
-            net = BetaVAE_3D
+            self.nc = 3
+            self.decoder_dist = 'gaussian'
         else:
             raise NotImplementedError

-        self.net = cuda(net(self.z_dim), self.use_cuda)
+        if args.model == 'H':
+            net = BetaVAE_H
+        elif args.model == 'B':
+            net = BetaVAE_B
+        else:
+            raise NotImplementedError('only support model H or B')
+
+        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
         self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                 betas=(self.beta1, self.beta2))

-        # Visdom
         self.viz_name = args.viz_name
         self.viz_port = args.viz_port
         self.viz_on = args.viz_on
         if self.viz_on:
-            self.viz = visdom.Visdom(env=self.viz_name, port=self.viz_port)
-            self.viz_curves = visdom.Visdom(env=self.viz_name+'/train_curves', port=self.viz_port)
+            self.viz = visdom.Visdom(env=self.viz_name+'_lines', port=self.viz_port)
             self.win_recon = None
             self.win_kld = None
+            self.win_mu = None
+            self.win_var = None

-        # Checkpoint
         self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
         if not self.ckpt_dir.exists():
             self.ckpt_dir.mkdir(parents=True, exist_ok=True)
+        self.ckpt_name = args.ckpt_name
+        if self.ckpt_name is not None:
+            self.load_checkpoint(self.ckpt_name)

-        self.load_ckpt = args.load_ckpt
-        if self.load_ckpt:
-            self.load_checkpoint()
-
-        # Data
         self.dset_dir = args.dset_dir
+        self.dataset = args.dataset
         self.batch_size = args.batch_size
         self.data_loader = return_data(args)

+        self.gather = DataGather()
+
     def train(self):
         self.net_mode(train=True)
+        self.C_max = Variable(cuda(torch.FloatTensor([self.C_max]), self.use_cuda))
         out = False
+
+        pbar = tqdm(total=self.max_iter)
+        pbar.update(self.global_iter)
         while not out:
-            start = time.time()
             curve_data = []
+            curves = dict(iter=[], total_kld=[], dim_wise_kld=[], mean_kld=[])
             for x in self.data_loader:
                 self.global_iter += 1
+                pbar.update(1)

                 x = Variable(cuda(x, self.use_cuda))
                 x_recon, mu, logvar = self.net(x)
-                recon_loss, kld = original_vae_loss(x, x_recon, mu, logvar)
+                recon_loss = reconstruction_loss(x, x_recon, self.decoder_dist)
+                total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)

-                beta_vae_loss = recon_loss + self.beta*kld
+                if self.objective == 'H':
+                    beta_vae_loss = recon_loss + self.beta*total_kld
+                elif self.objective == 'B':
+                    C = torch.clamp(self.C_max/self.C_stop_iter*self.global_iter, 0, self.C_max.data[0])
+                    beta_vae_loss = recon_loss + self.gamma*(total_kld-C).abs()

                 self.optim.zero_grad()
                 beta_vae_loss.backward()
                 self.optim.step()

                 if self.global_iter%1000 == 0:
-                    curve_data.append(torch.Tensor([self.global_iter,
-                                                    recon_loss.data[0],
-                                                    kld.data[0],]))
+                    self.gather.insert(iter=self.global_iter,
+                                       mu=mu.mean(0).data, var=logvar.exp().mean(0).data,
+                                       recon_loss=recon_loss.data, total_kld=total_kld.data,
+                                       dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data)

                 if self.global_iter%5000 == 0:
-                    self.save_checkpoint()
-                    self.visualize(dict(image=[x, x_recon], curve=curve_data))
-                    print('[{}] recon_loss:{:.3f} beta*kld:{:.3f}'.format(
-                        self.global_iter, recon_loss.data[0], self.beta*kld.data[0]))
-                    curve_data = []
-
-                if self.global_iter%100000 == 0:
+                    self.gather.insert(images=x.data)
+                    self.gather.insert(images=x_recon.data)
+                    self.visualize()
+                    self.gather.flush()
+                    self.save_checkpoint('last')
+                    pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
+                        self.global_iter, recon_loss.data[0], total_kld.data[0], mean_kld.data[0]))
+
+                    var = logvar.exp().mean(0).data
+                    var_str = ''
+                    for j, var_j in enumerate(var):
+                        var_str += 'var{}:{:.4f} '.format(j+1, var_j)
+                    pbar.write(var_str)
+
+                    if self.objective == 'B':
+                        pbar.write('C:{:.3f}'.format(C.data[0]))
+
+                if self.global_iter%10000 == 0:
                     self.traverse()

+                if self.global_iter%50000 == 0:
+                    self.save_checkpoint(str(self.global_iter))
+
                 if self.global_iter >= self.max_iter:
                     out = True
                     break

-            end = time.time()
-            print('[time elapsed] {:.2f}s/epoch'.format(end-start))
-        print("[Training Finished]")
+        pbar.write("[Training Finished]")
+        pbar.close()

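Note: train() above is where the two objectives from the papers diverge: H (Higgins et al.) penalizes the full KL with a constant beta, while B (Burgess et al.) pulls the KL toward an annealed capacity C with weight gamma. Isolated from the old Variable/.data[0] API, the core computation is (a sketch with plain tensor/float arguments; names are illustrative):

    import torch

    def beta_vae_loss(recon_loss, total_kld, objective, beta=4.0, gamma=1000.0, C=0.0):
        if objective == 'H':
            # Higgins et al. (2017): L = recon + beta * KL
            return recon_loss + beta * total_kld
        elif objective == 'B':
            # Burgess et al. (2018): L = recon + gamma * |KL - C|
            return recon_loss + gamma * (total_kld - C).abs()
        raise ValueError("objective must be 'H' or 'B'")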
-    def visualize(self, data):
-        x, x_recon = data['image']
-        curve_data = data['curve']
+    def visualize(self):
+        self.net_mode(train=False)
+        x = self.gather.data['images'][0][:100]
+        x = make_grid(x, normalize=False)
+        x_recon = F.sigmoid(self.gather.data['images'][1])[:100]
+        x_recon = make_grid(x_recon, normalize=False)
+        images = torch.stack([x, x_recon], dim=0).cpu()
+        self.viz.images(images, env=self.viz_name+'_reconstruction',
+                        opts=dict(title=str(self.global_iter)), nrow=10)

-        sample_x = make_grid(x.data.cpu(), normalize=False)
-        sample_x_recon = make_grid(F.sigmoid(x_recon).data.cpu(), normalize=False)
-        samples = torch.stack([sample_x, sample_x_recon], dim=0)
-        self.viz.images(samples, opts=dict(title=str(self.global_iter)))
+        recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()

-        curve_data = torch.stack(curve_data, dim=0)
-        curve_iter = curve_data[:, 0]
-        curve_recon = curve_data[:, 1]
-        curve_kld = curve_data[:, 2]
+        mus = torch.stack(self.gather.data['mu']).cpu()
+        vars = torch.stack(self.gather.data['var']).cpu()
+
+        dim_wise_klds = torch.stack(self.gather.data['dim_wise_kld'])
+        mean_klds = torch.stack(self.gather.data['mean_kld'])
+        total_klds = torch.stack(self.gather.data['total_kld'])
+        klds = torch.cat([dim_wise_klds, mean_klds, total_klds], 1).cpu()
+        legend = []
+        for z_j in range(self.z_dim):
+            legend.append('z_{}'.format(z_j))
+        legend.append('mean')
+        legend.append('total')
+
+        iters = torch.Tensor(self.gather.data['iter'])

         if self.win_recon is None:
-            self.win_recon = self.viz_curves.line(
-                                        X=curve_iter,
-                                        Y=curve_recon,
+            self.win_recon = self.viz.line(
+                                        X=iters,
+                                        Y=recon_losses,
+                                        env=self.viz_name+'_lines',
                                         opts=dict(
+                                            width=400,
+                                            height=400,
                                             xlabel='iteration',
-                                            ylabel='reconsturction loss',))
+                                            title='reconstruction loss',))
         else:
-            self.win_recon = self.viz_curves.line(
-                                        X=curve_iter,
-                                        Y=curve_recon,
+            self.win_recon = self.viz.line(
+                                        X=iters,
+                                        Y=recon_losses,
+                                        env=self.viz_name+'_lines',
                                         win=self.win_recon,
                                         update='append',
                                         opts=dict(
+                                            width=400,
+                                            height=400,
                                             xlabel='iteration',
-                                            ylabel='reconsturction loss',))
+                                            title='reconstruction loss',))

         if self.win_kld is None:
-            self.win_kld = self.viz_curves.line(
-                                        X=curve_iter,
-                                        Y=curve_kld,
+            self.win_kld = self.viz.line(
+                                        X=iters,
+                                        Y=klds,
+                                        env=self.viz_name+'_lines',
                                         opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend,
                                             xlabel='iteration',
-                                            ylabel='kl divergence',))
+                                            title='kl divergence',))
         else:
-            self.win_kld = self.viz_curves.line(
-                                        X=curve_iter,
-                                        Y=curve_kld,
+            self.win_kld = self.viz.line(
+                                        X=iters,
+                                        Y=klds,
+                                        env=self.viz_name+'_lines',
                                         win=self.win_kld,
                                         update='append',
                                         opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend,
                                             xlabel='iteration',
-                                            ylabel='kl divergence',))
+                                            title='kl divergence',))

-    def traverse(self):
+        if self.win_mu is None:
+            self.win_mu = self.viz.line(
+                                        X=iters,
+                                        Y=mus,
+                                        env=self.viz_name+'_lines',
+                                        opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend[:self.z_dim],
+                                            xlabel='iteration',
+                                            title='posterior mean',))
+        else:
+            self.win_mu = self.viz.line(
+                                        X=iters,
+                                        Y=mus,
+                                        env=self.viz_name+'_lines',
+                                        win=self.win_mu,
+                                        update='append',
+                                        opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend[:self.z_dim],
+                                            xlabel='iteration',
+                                            title='posterior mean',))
+
+        if self.win_var is None:
+            self.win_var = self.viz.line(
+                                        X=iters,
+                                        Y=vars,
+                                        env=self.viz_name+'_lines',
+                                        opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend[:self.z_dim],
+                                            xlabel='iteration',
+                                            title='posterior variance',))
+        else:
+            self.win_var = self.viz.line(
+                                        X=iters,
+                                        Y=vars,
+                                        env=self.viz_name+'_lines',
+                                        win=self.win_var,
+                                        update='append',
+                                        opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend[:self.z_dim],
+                                            xlabel='iteration',
+                                            title='posterior variance',))
+        self.net_mode(train=True)
+
+    def traverse(self, limit=3, inter=2/3):
+        self.net_mode(train=False)
         import random

-        decoder = self.net.decode
-        encoder = self.net.encode
-        interpolation = torch.arange(-6, 6.1, 1)
+        decoder = self.net.decoder
+        encoder = self.net.encoder
+        interpolation = torch.arange(-limit, limit+0.1, inter)

         viz = visdom.Visdom(env=self.viz_name+'/traverse', port=self.viz_port)

-        n_dsets = self.data_loader.dataset.__len__()
-        fixed_idx = 0
+        n_dsets = len(self.data_loader.dataset)
         rand_idx = random.randint(1, n_dsets-1)

-        fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
-        fixed_img = Variable(cuda(fixed_img, self.use_cuda), volatile=True).unsqueeze(0)
-        fixed_img_z = encoder(fixed_img)[:, :self.z_dim]
-
         random_img = self.data_loader.dataset.__getitem__(rand_idx)
         random_img = Variable(cuda(random_img, self.use_cuda), volatile=True).unsqueeze(0)
         random_img_z = encoder(random_img)[:, :self.z_dim]

-        zero_z = Variable(cuda(torch.zeros(1, self.z_dim, 1, 1), self.use_cuda), volatile=True)
-        random_z = Variable(cuda(torch.rand(1, self.z_dim, 1, 1), self.use_cuda), volatile=True)
+        random_z = Variable(cuda(torch.rand(1, self.z_dim), self.use_cuda), volatile=True)
+
+        if self.dataset == 'dsprites':
+            fixed_idx1 = 87040  # square
+            fixed_idx2 = 332800 # ellipse
+            fixed_idx3 = 578560 # heart
+
+            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)
+            fixed_img1 = Variable(cuda(fixed_img1, self.use_cuda), volatile=True).unsqueeze(0)
+            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
+
+            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)
+            fixed_img2 = Variable(cuda(fixed_img2, self.use_cuda), volatile=True).unsqueeze(0)
+            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
+
+            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)
+            fixed_img3 = Variable(cuda(fixed_img3, self.use_cuda), volatile=True).unsqueeze(0)
+            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
+
+            Z = {'fixed_square':fixed_img_z1, 'fixed_ellipse':fixed_img_z2,
+                 'fixed_heart':fixed_img_z3, 'random_img':random_img_z}
+        else:
+            fixed_idx = 0
+            fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
+            fixed_img = Variable(cuda(fixed_img, self.use_cuda), volatile=True).unsqueeze(0)
+            fixed_img_z = encoder(fixed_img)[:, :self.z_dim]
+
+            Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z}

-        Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z, 'zero_z':zero_z}
         for key in Z.keys():
             z_ori = Z[key]
             samples = []
@@ -209,9 +379,11 @@
                     sample = F.sigmoid(decoder(z))
                     samples.append(sample)
             samples = torch.cat(samples, dim=0).data.cpu()
-            title = '{}_row_traverse(iter:{})'.format(key, self.global_iter)
+            title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
             viz.images(samples, opts=dict(title=title), nrow=len(interpolation))

+        self.net_mode(train=True)
+
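Note: traverse() renders latent traversals by fixing a reference code z and sweeping one dimension at a time over [-limit, limit]; the dsprites branch pins one example per shape (square/ellipse/heart) so the same rows stay comparable across checkpoints. The core of the sweep, stripped of visdom and the old Variable API, is roughly (a sketch, not part of this patch):

    import torch

    @torch.no_grad()
    def traverse_dim(decoder, z, dim, limit=3.0, steps=10):
        # decode copies of z whose `dim`-th coordinate sweeps [-limit, limit]
        frames = []
        for val in torch.linspace(-limit, limit, steps):
            z_mod = z.clone()
            z_mod[:, dim] = val
            frames.append(torch.sigmoid(decoder(z_mod)))
        return torch.cat(frames, dim=0)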
     def net_mode(self, train):
         if not isinstance(train, bool):
             raise ValueError('Only bool type is supported. True or False')
@@ -221,11 +393,13 @@
         else:
             self.net.eval()

-    def save_checkpoint(self, filename='ckpt.tar', silent=True):
+    def save_checkpoint(self, filename, silent=True):
         model_states = {'net':self.net.state_dict(),}
         optim_states = {'optim':self.optim.state_dict(),}
         win_states = {'recon':self.win_recon,
-                      'kld':self.win_kld,}
+                      'kld':self.win_kld,
+                      'mu':self.win_mu,
+                      'var':self.win_var,}
         states = {'iter':self.global_iter,
                   'win_states':win_states,
                   'model_states':model_states,
@@ -236,13 +410,15 @@
         if not silent:
             print("=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter))

-    def load_checkpoint(self, filename='ckpt.tar'):
+    def load_checkpoint(self, filename):
         file_path = self.ckpt_dir.joinpath(filename)
         if file_path.is_file():
             checkpoint = torch.load(file_path.open('rb'))
             self.global_iter = checkpoint['iter']
             self.win_recon = checkpoint['win_states']['recon']
             self.win_kld = checkpoint['win_states']['kld']
+            self.win_var = checkpoint['win_states']['var']
+            self.win_mu = checkpoint['win_states']['mu']
             self.net.load_state_dict(checkpoint['model_states']['net'])
             self.optim.load_state_dict(checkpoint['optim_states']['optim'])
             print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))
-- 
GitLab
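Note: save_checkpoint() bundles the iteration counter, the visdom window handles, and the model/optimizer state dicts into a single file; train() writes it as 'last' every 5000 iterations and under the iteration number every 50000. Loading one outside the Solver, e.g. for offline evaluation, would look roughly like this (a sketch; the path assumes ckpt_dir=checkpoints and viz_name=dsprites_B):

    import torch
    from model import BetaVAE_B

    ckpt = torch.load('checkpoints/dsprites_B/last', map_location='cpu')
    net = BetaVAE_B(z_dim=10, nc=1)
    net.load_state_dict(ckpt['model_states']['net'])
    net.eval()
    print('loaded iter', ckpt['iter'])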