From 9af77f77bc0d1e11a54adb166f8a3adc4c11f84c Mon Sep 17 00:00:00 2001
From: 1Konny <lee.wonkwang94@gmail.com>
Date: Wed, 23 May 2018 23:22:08 +0900
Subject: [PATCH] add traverse -> gif

---
 solver.py | 54 +++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 41 insertions(+), 13 deletions(-)

diff --git a/solver.py b/solver.py
index cd9176e..128fbe7 100644
--- a/solver.py
+++ b/solver.py
@@ -8,9 +8,9 @@ import torch
 import torch.optim as optim
 import torch.nn.functional as F
 from torch.autograd import Variable
-from torchvision.utils import make_grid
+from torchvision.utils import make_grid, save_image
 
-from utils import cuda
+from utils import cuda, grid2gif
 from model import BetaVAE_H, BetaVAE_B
 from dataset import return_data
 
@@ -90,7 +90,7 @@ class Solver(object):
             self.decoder_dist = 'bernoulli'
         elif args.dataset.lower() == '3dchairs':
             self.nc = 3
-            self.decoder_dist = 'bernoulli'
+            self.decoder_dist = 'gaussian'
         elif args.dataset.lower() == 'celeba':
             self.nc = 3
             self.decoder_dist = 'gaussian'
@@ -125,12 +125,18 @@ class Solver(object):
         if self.ckpt_name is not None:
             self.load_checkpoint(self.ckpt_name)
 
+        self.save_output = args.save_output
+        self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
+        if not self.output_dir.exists():
+            self.output_dir.mkdir(parents=True, exist_ok=True)
+
         self.dset_dir = args.dset_dir
         self.dataset = args.dataset
         self.batch_size = args.batch_size
         self.data_loader = return_data(args)
 
         self.gather = DataGather()
+        import ipdb; ipdb.set_trace()
 
     def train(self):
         self.net_mode(train=True)
@@ -167,10 +173,11 @@ class Solver(object):
                                        recon_loss=recon_loss.data, total_kld=total_kld.data,
                                        dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data)
 
-                if self.global_iter%5000 == 0:
+                if self.global_iter%10000 == 0:
                     self.gather.insert(images=x.data)
                     self.gather.insert(images=x_recon.data)
-                    self.visualize()
+                    self.viz_reconstruction()
+                    self.viz_lines()
                     self.gather.flush()
                     self.save_checkpoint('last')
                     pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
@@ -185,8 +192,8 @@ class Solver(object):
                     if self.objective == 'advanced':
                         pbar.write('C:{:.3f}'.format(C.data[0]))
 
-                if self.global_iter%10000 == 0:
-                    self.traverse()
+                if self.global_iter%20000 == 0:
+                    self.viz_traverse()
 
                 if self.global_iter%50000 == 0:
                     self.save_checkpoint(str(self.global_iter))
@@ -198,7 +205,7 @@ class Solver(object):
         pbar.write("[Training Finished]")
         pbar.close()
 
-    def visualize(self):
+    def viz_reconstruction(self):
         self.net_mode(train=False)
         x = self.gather.data['images'][0][:100]
         x = make_grid(x, normalize=False)
@@ -207,7 +214,10 @@ class Solver(object):
         images = torch.stack([x, x_recon], dim=0).cpu()
         self.viz.images(images, env=self.viz_name+'_reconstruction',
                         opts=dict(title=str(self.global_iter)), nrow=10)
+        self.net_mode(train=True)
 
+    def viz_lines(self):
+        self.net_mode(train=False)
         recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
 
         mus = torch.stack(self.gather.data['mu']).cpu()
@@ -324,14 +334,13 @@ class Solver(object):
                                             title='posterior variance',))
         self.net_mode(train=True)
 
-    def traverse(self, limit=3, inter=2/3):
+    def viz_traverse(self, limit=3, inter=2/3, loc=-1):
         self.net_mode(train=False)
         import random
 
         decoder = self.net.decoder
         encoder = self.net.encoder
         interpolation = torch.arange(-limit, limit+0.1, inter)
-        viz = visdom.Visdom(env=self.viz_name+'/traverse', port=self.viz_port)
 
         n_dsets = len(self.data_loader.dataset)
         rand_idx = random.randint(1, n_dsets-1)
@@ -369,18 +378,37 @@ class Solver(object):
 
             Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z}
 
+        gifs = []
         for key in Z.keys():
             z_ori = Z[key]
             samples = []
             for row in range(self.z_dim):
+                if loc != -1 and row != loc:
+                    continue
                 z = z_ori.clone()
                 for val in interpolation:
                     z[:, row] = val
-                    sample = F.sigmoid(decoder(z))
+                    sample = F.sigmoid(decoder(z)).data
                     samples.append(sample)
-            samples = torch.cat(samples, dim=0).data.cpu()
+                    gifs.append(sample)
+            samples = torch.cat(samples, dim=0).cpu()
             title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
-            viz.images(samples, opts=dict(title=title), nrow=len(interpolation))
+            self.viz.images(samples, env=self.viz_name+'_traverse',
+                            opts=dict(title=title), nrow=len(interpolation))
+
+        if self.save_output:
+            output_dir = self.output_dir.joinpath(str(self.global_iter))
+            output_dir.mkdir(parents=True, exist_ok=True)
+            gifs = torch.cat(gifs)
+            gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64, 64).transpose(1, 2)
+            for i, key in enumerate(Z.keys()):
+                for j, val in enumerate(interpolation):
+                    save_image(tensor=gifs[i][j].cpu(),
+                               filename=output_dir.joinpath('{}_{}.jpg'.format(key, j)),
+                               nrow=self.z_dim, pad_value=1)
+
+                grid2gif(str(output_dir.joinpath(key+'*.jpg')),
+                         str(output_dir.joinpath(key+'.gif')), delay=10)
 
         self.net_mode(train=True)
 
-- 
GitLab