From 7d4d80e2721d8b42c8c7d82dc1720980175576b2 Mon Sep 17 00:00:00 2001 From: 1Konny <lee.wonkwang94@gmail.com> Date: Tue, 17 Apr 2018 10:43:55 +0900 Subject: [PATCH] add traverse --- solver.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/solver.py b/solver.py index 5c63823..a7abcdf 100644 --- a/solver.py +++ b/solver.py @@ -163,6 +163,51 @@ class Solver(object): xlabel='iteration', ylabel='kl divergence',)) + def traverse(self): + import random + + decoder = self.net.decode + encoder = self.net.encode + interpolation = torch.arange(-6, 6.1, 2) + viz = visdom.Visdom(env=self.viz_name+'/traverse', port=self.viz_port) + + n_dsets = self.data_loader.dataset.__len__() + fixed_idx = 0 + rand_idx = random.randint(1, n_dsets-1) + + fixed_img = self.data_loader.dataset.__getitem__(fixed_idx) + random_img = self.data_loader.dataset.__getitem__(rand_idx) + fixed_img = Variable(cuda(fixed_img, self.use_cuda), volatile=True).unsqueeze(0) + random_img = Variable(cuda(random_img, self.use_cuda), volatile=True).unsqueeze(0) + zero_z = Variable(cuda(torch.zeros(1, self.z_dim, 1, 1), self.use_cuda), volatile=True) + random_z = Variable(cuda(torch.rand(1, self.z_dim, 1, 1), self.use_cuda), volatile=True) + + src = [fixed_img, random_img, zero_z, random_z] + for i, vector in enumerate(src): + if i < 2: + z_ori = encoder(vector)[:, :self.z_dim] + else: + z_ori = vector + + samples = [] + for row in range(self.z_dim): + z = z_ori.clone() + for val in interpolation: + z[:, row] = val + sample = F.sigmoid(decoder(z)) + samples.append(sample) + samples = torch.cat(samples, dim=0).data.cpu() + if i==0: + title = 'traverse representation from fixed image' + elif i==1: + title = 'traverse representation random image' + elif i==2: + title = 'traverse zero representation vector' + elif i==3: + title = 'traverse random gaussian representation vector' + title += '(iter:{})'.format(self.global_iter) + viz.images(samples, opts=dict(title=title), nrow=len(interpolation)) + def net_mode(self, train): if not isinstance(train, bool): raise('Only bool type is supported. True or False') -- GitLab