Skip to content
Snippets Groups Projects
Commit 9af77f77 authored by 1Konny's avatar 1Konny
Browse files

add traverse -> gif

parent f050d1a1
No related branches found
No related tags found
No related merge requests found
...@@ -8,9 +8,9 @@ import torch ...@@ -8,9 +8,9 @@ import torch
import torch.optim as optim import torch.optim as optim
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Variable from torch.autograd import Variable
from torchvision.utils import make_grid from torchvision.utils import make_grid, save_image
from utils import cuda from utils import cuda, grid2gif
from model import BetaVAE_H, BetaVAE_B from model import BetaVAE_H, BetaVAE_B
from dataset import return_data from dataset import return_data
...@@ -90,7 +90,7 @@ class Solver(object): ...@@ -90,7 +90,7 @@ class Solver(object):
self.decoder_dist = 'bernoulli' self.decoder_dist = 'bernoulli'
elif args.dataset.lower() == '3dchairs': elif args.dataset.lower() == '3dchairs':
self.nc = 3 self.nc = 3
self.decoder_dist = 'bernoulli' self.decoder_dist = 'gaussian'
elif args.dataset.lower() == 'celeba': elif args.dataset.lower() == 'celeba':
self.nc = 3 self.nc = 3
self.decoder_dist = 'gaussian' self.decoder_dist = 'gaussian'
...@@ -125,12 +125,18 @@ class Solver(object): ...@@ -125,12 +125,18 @@ class Solver(object):
if self.ckpt_name is not None: if self.ckpt_name is not None:
self.load_checkpoint(self.ckpt_name) self.load_checkpoint(self.ckpt_name)
self.save_output = args.save_output
self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
if not self.output_dir.exists():
self.output_dir.mkdir(parents=True, exist_ok=True)
self.dset_dir = args.dset_dir self.dset_dir = args.dset_dir
self.dataset = args.dataset self.dataset = args.dataset
self.batch_size = args.batch_size self.batch_size = args.batch_size
self.data_loader = return_data(args) self.data_loader = return_data(args)
self.gather = DataGather() self.gather = DataGather()
import ipdb; ipdb.set_trace()
def train(self): def train(self):
self.net_mode(train=True) self.net_mode(train=True)
...@@ -167,10 +173,11 @@ class Solver(object): ...@@ -167,10 +173,11 @@ class Solver(object):
recon_loss=recon_loss.data, total_kld=total_kld.data, recon_loss=recon_loss.data, total_kld=total_kld.data,
dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data) dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data)
if self.global_iter%5000 == 0: if self.global_iter%10000 == 0:
self.gather.insert(images=x.data) self.gather.insert(images=x.data)
self.gather.insert(images=x_recon.data) self.gather.insert(images=x_recon.data)
self.visualize() self.viz_reconstruction()
self.viz_lines()
self.gather.flush() self.gather.flush()
self.save_checkpoint('last') self.save_checkpoint('last')
pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format( pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
...@@ -185,8 +192,8 @@ class Solver(object): ...@@ -185,8 +192,8 @@ class Solver(object):
if self.objective == 'advanced': if self.objective == 'advanced':
pbar.write('C:{:.3f}'.format(C.data[0])) pbar.write('C:{:.3f}'.format(C.data[0]))
if self.global_iter%10000 == 0: if self.global_iter%20000 == 0:
self.traverse() self.viz_traverse()
if self.global_iter%50000 == 0: if self.global_iter%50000 == 0:
self.save_checkpoint(str(self.global_iter)) self.save_checkpoint(str(self.global_iter))
...@@ -198,7 +205,7 @@ class Solver(object): ...@@ -198,7 +205,7 @@ class Solver(object):
pbar.write("[Training Finished]") pbar.write("[Training Finished]")
pbar.close() pbar.close()
def visualize(self): def viz_reconstruction(self):
self.net_mode(train=False) self.net_mode(train=False)
x = self.gather.data['images'][0][:100] x = self.gather.data['images'][0][:100]
x = make_grid(x, normalize=False) x = make_grid(x, normalize=False)
...@@ -207,7 +214,10 @@ class Solver(object): ...@@ -207,7 +214,10 @@ class Solver(object):
images = torch.stack([x, x_recon], dim=0).cpu() images = torch.stack([x, x_recon], dim=0).cpu()
self.viz.images(images, env=self.viz_name+'_reconstruction', self.viz.images(images, env=self.viz_name+'_reconstruction',
opts=dict(title=str(self.global_iter)), nrow=10) opts=dict(title=str(self.global_iter)), nrow=10)
self.net_mode(train=True)
def viz_lines(self):
self.net_mode(train=False)
recon_losses = torch.stack(self.gather.data['recon_loss']).cpu() recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
mus = torch.stack(self.gather.data['mu']).cpu() mus = torch.stack(self.gather.data['mu']).cpu()
...@@ -324,14 +334,13 @@ class Solver(object): ...@@ -324,14 +334,13 @@ class Solver(object):
title='posterior variance',)) title='posterior variance',))
self.net_mode(train=True) self.net_mode(train=True)
def traverse(self, limit=3, inter=2/3): def viz_traverse(self, limit=3, inter=2/3, loc=-1):
self.net_mode(train=False) self.net_mode(train=False)
import random import random
decoder = self.net.decoder decoder = self.net.decoder
encoder = self.net.encoder encoder = self.net.encoder
interpolation = torch.arange(-limit, limit+0.1, inter) interpolation = torch.arange(-limit, limit+0.1, inter)
viz = visdom.Visdom(env=self.viz_name+'/traverse', port=self.viz_port)
n_dsets = len(self.data_loader.dataset) n_dsets = len(self.data_loader.dataset)
rand_idx = random.randint(1, n_dsets-1) rand_idx = random.randint(1, n_dsets-1)
...@@ -369,18 +378,37 @@ class Solver(object): ...@@ -369,18 +378,37 @@ class Solver(object):
Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z} Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z}
gifs = []
for key in Z.keys(): for key in Z.keys():
z_ori = Z[key] z_ori = Z[key]
samples = [] samples = []
for row in range(self.z_dim): for row in range(self.z_dim):
if loc != -1 and row != loc:
continue
z = z_ori.clone() z = z_ori.clone()
for val in interpolation: for val in interpolation:
z[:, row] = val z[:, row] = val
sample = F.sigmoid(decoder(z)) sample = F.sigmoid(decoder(z)).data
samples.append(sample) samples.append(sample)
samples = torch.cat(samples, dim=0).data.cpu() gifs.append(sample)
samples = torch.cat(samples, dim=0).cpu()
title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter) title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
viz.images(samples, opts=dict(title=title), nrow=len(interpolation)) self.viz.images(samples, env=self.viz_name+'_traverse',
opts=dict(title=title), nrow=len(interpolation))
if self.save_output:
output_dir = self.output_dir.joinpath(str(self.global_iter))
output_dir.mkdir(parents=True, exist_ok=True)
gifs = torch.cat(gifs)
gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64, 64).transpose(1, 2)
for i, key in enumerate(Z.keys()):
for j, val in enumerate(interpolation):
save_image(tensor=gifs[i][j].cpu(),
filename=output_dir.joinpath('{}_{}.jpg'.format(key, j)),
nrow=self.z_dim, pad_value=1)
grid2gif(str(output_dir.joinpath(key+'*.jpg')),
str(output_dir.joinpath(key+'.gif')), delay=10)
self.net_mode(train=True) self.net_mode(train=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment