From 7d4d80e2721d8b42c8c7d82dc1720980175576b2 Mon Sep 17 00:00:00 2001
From: 1Konny <lee.wonkwang94@gmail.com>
Date: Tue, 17 Apr 2018 10:43:55 +0900
Subject: [PATCH] add traverse

---
 solver.py | 45 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/solver.py b/solver.py
index 5c63823..a7abcdf 100644
--- a/solver.py
+++ b/solver.py
@@ -163,6 +163,51 @@ class Solver(object):
                                             xlabel='iteration',
                                             ylabel='kl divergence',))
 
+    def traverse(self):
+        import random
+
+        decoder = self.net.decode
+        encoder = self.net.encode
+        interpolation = torch.arange(-6, 6.1, 2)
+        viz = visdom.Visdom(env=self.viz_name+'/traverse', port=self.viz_port)
+
+        n_dsets = self.data_loader.dataset.__len__()
+        fixed_idx = 0
+        rand_idx = random.randint(1, n_dsets-1)
+
+        fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
+        random_img = self.data_loader.dataset.__getitem__(rand_idx)
+        fixed_img = Variable(cuda(fixed_img, self.use_cuda), volatile=True).unsqueeze(0)
+        random_img = Variable(cuda(random_img, self.use_cuda), volatile=True).unsqueeze(0)
+        zero_z = Variable(cuda(torch.zeros(1, self.z_dim, 1, 1), self.use_cuda), volatile=True)
+        random_z = Variable(cuda(torch.rand(1, self.z_dim, 1, 1), self.use_cuda), volatile=True)
+
+        src = [fixed_img, random_img, zero_z, random_z]
+        for i, vector in enumerate(src):
+            if i < 2:
+                z_ori = encoder(vector)[:, :self.z_dim]
+            else:
+                z_ori = vector
+
+            samples = []
+            for row in range(self.z_dim):
+                z = z_ori.clone()
+                for val in interpolation:
+                    z[:, row] = val
+                    sample = F.sigmoid(decoder(z))
+                    samples.append(sample)
+            samples = torch.cat(samples, dim=0).data.cpu()
+            if i==0:
+                title = 'traverse representation from fixed image'
+            elif i==1:
+                title = 'traverse representation random image'
+            elif i==2:
+                title = 'traverse zero representation vector'
+            elif i==3:
+                title = 'traverse random gaussian representation vector'
+            title += '(iter:{})'.format(self.global_iter)
+            viz.images(samples, opts=dict(title=title), nrow=len(interpolation))
+
     def net_mode(self, train):
         if not isinstance(train, bool):
             raise('Only bool type is supported. True or False')
-- 
GitLab