From 9af77f77bc0d1e11a54adb166f8a3adc4c11f84c Mon Sep 17 00:00:00 2001 From: 1Konny <lee.wonkwang94@gmail.com> Date: Wed, 23 May 2018 23:22:08 +0900 Subject: [PATCH] add traverse -> gif --- solver.py | 54 +++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/solver.py b/solver.py index cd9176e..128fbe7 100644 --- a/solver.py +++ b/solver.py @@ -8,9 +8,9 @@ import torch import torch.optim as optim import torch.nn.functional as F from torch.autograd import Variable -from torchvision.utils import make_grid +from torchvision.utils import make_grid, save_image -from utils import cuda +from utils import cuda, grid2gif from model import BetaVAE_H, BetaVAE_B from dataset import return_data @@ -90,7 +90,7 @@ class Solver(object): self.decoder_dist = 'bernoulli' elif args.dataset.lower() == '3dchairs': self.nc = 3 - self.decoder_dist = 'bernoulli' + self.decoder_dist = 'gaussian' elif args.dataset.lower() == 'celeba': self.nc = 3 self.decoder_dist = 'gaussian' @@ -125,12 +125,18 @@ class Solver(object): if self.ckpt_name is not None: self.load_checkpoint(self.ckpt_name) + self.save_output = args.save_output + self.output_dir = Path(args.output_dir).joinpath(args.viz_name) + if not self.output_dir.exists(): + self.output_dir.mkdir(parents=True, exist_ok=True) + self.dset_dir = args.dset_dir self.dataset = args.dataset self.batch_size = args.batch_size self.data_loader = return_data(args) self.gather = DataGather() + import ipdb; ipdb.set_trace() def train(self): self.net_mode(train=True) @@ -167,10 +173,11 @@ class Solver(object): recon_loss=recon_loss.data, total_kld=total_kld.data, dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data) - if self.global_iter%5000 == 0: + if self.global_iter%10000 == 0: self.gather.insert(images=x.data) self.gather.insert(images=x_recon.data) - self.visualize() + self.viz_reconstruction() + self.viz_lines() self.gather.flush() self.save_checkpoint('last') pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format( @@ -185,8 +192,8 @@ class Solver(object): if self.objective == 'advanced': pbar.write('C:{:.3f}'.format(C.data[0])) - if self.global_iter%10000 == 0: - self.traverse() + if self.global_iter%20000 == 0: + self.viz_traverse() if self.global_iter%50000 == 0: self.save_checkpoint(str(self.global_iter)) @@ -198,7 +205,7 @@ class Solver(object): pbar.write("[Training Finished]") pbar.close() - def visualize(self): + def viz_reconstruction(self): self.net_mode(train=False) x = self.gather.data['images'][0][:100] x = make_grid(x, normalize=False) @@ -207,7 +214,10 @@ class Solver(object): images = torch.stack([x, x_recon], dim=0).cpu() self.viz.images(images, env=self.viz_name+'_reconstruction', opts=dict(title=str(self.global_iter)), nrow=10) + self.net_mode(train=True) + def viz_lines(self): + self.net_mode(train=False) recon_losses = torch.stack(self.gather.data['recon_loss']).cpu() mus = torch.stack(self.gather.data['mu']).cpu() @@ -324,14 +334,13 @@ class Solver(object): title='posterior variance',)) self.net_mode(train=True) - def traverse(self, limit=3, inter=2/3): + def viz_traverse(self, limit=3, inter=2/3, loc=-1): self.net_mode(train=False) import random decoder = self.net.decoder encoder = self.net.encoder interpolation = torch.arange(-limit, limit+0.1, inter) - viz = visdom.Visdom(env=self.viz_name+'/traverse', port=self.viz_port) n_dsets = len(self.data_loader.dataset) rand_idx = random.randint(1, n_dsets-1) @@ -369,18 +378,37 @@ class Solver(object): Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z} + gifs = [] for key in Z.keys(): z_ori = Z[key] samples = [] for row in range(self.z_dim): + if loc != -1 and row != loc: + continue z = z_ori.clone() for val in interpolation: z[:, row] = val - sample = F.sigmoid(decoder(z)) + sample = F.sigmoid(decoder(z)).data samples.append(sample) - samples = torch.cat(samples, dim=0).data.cpu() + gifs.append(sample) + samples = torch.cat(samples, dim=0).cpu() title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter) - viz.images(samples, opts=dict(title=title), nrow=len(interpolation)) + self.viz.images(samples, env=self.viz_name+'_traverse', + opts=dict(title=title), nrow=len(interpolation)) + + if self.save_output: + output_dir = self.output_dir.joinpath(str(self.global_iter)) + output_dir.mkdir(parents=True, exist_ok=True) + gifs = torch.cat(gifs) + gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64, 64).transpose(1, 2) + for i, key in enumerate(Z.keys()): + for j, val in enumerate(interpolation): + save_image(tensor=gifs[i][j].cpu(), + filename=output_dir.joinpath('{}_{}.jpg'.format(key, j)), + nrow=self.z_dim, pad_value=1) + + grid2gif(str(output_dir.joinpath(key+'*.jpg')), + str(output_dir.joinpath(key+'.gif')), delay=10) self.net_mode(train=True) -- GitLab