From 23811d9879559c1ae03560be101ed2952bac0248 Mon Sep 17 00:00:00 2001
From: 1Konny <lee.wonkwang94@gmail.com>
Date: Tue, 22 May 2018 18:34:55 +0900
Subject: [PATCH] add new version

---
 main.py            |  27 ++--
 model.py           | 173 ++++++++++++++--------
 run_3dchairs.sh    |   3 -
 run_3dchairs_H.sh  |   5 +
 run_celeba.sh      |   3 -
 run_celeba_H.sh    |   5 +
 run_dsprites.sh    |   3 -
 run_dsprites_B.sh  |   5 +
 run_dsprites_B2.sh |   5 +
 solver.py          | 358 +++++++++++++++++++++++++++++++++------------
 10 files changed, 411 insertions(+), 176 deletions(-)
 delete mode 100644 run_3dchairs.sh
 create mode 100644 run_3dchairs_H.sh
 delete mode 100644 run_celeba.sh
 create mode 100644 run_celeba_H.sh
 delete mode 100644 run_dsprites.sh
 create mode 100644 run_dsprites_B.sh
 create mode 100644 run_dsprites_B2.sh

diff --git a/main.py b/main.py
index c2a626f..85f230c 100644
--- a/main.py
+++ b/main.py
@@ -11,16 +11,13 @@ from utils import str2bool
 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True
 
-init_seed = 1
-torch.manual_seed(init_seed)
-torch.cuda.manual_seed(init_seed)
-np.random.seed(init_seed)
-
-np.set_printoptions(precision=4)
-torch.set_printoptions(precision=4)
-
 
 def main(args):
+    seed = args.seed
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+
     net = Solver(args)
 
     if args.train:
@@ -33,15 +30,21 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='toy Beta-VAE')
 
     parser.add_argument('--train', default=True, type=str2bool, help='train or traverse')
+    parser.add_argument('--seed', default=1, type=int, help='random seed')
     parser.add_argument('--cuda', default=True, type=str2bool, help='enable cuda')
     parser.add_argument('--max_iter', default=1e6, type=float, help='maximum training iteration')
     parser.add_argument('--batch_size', default=64, type=int, help='batch size')
 
     parser.add_argument('--z_dim', default=10, type=int, help='dimension of the representation z')
-    parser.add_argument('--beta', default=4, type=float, help='beta parameter for KL-term')
+    parser.add_argument('--beta', default=4, type=float, help='beta parameter for KL-term in original beta-VAE')
+    parser.add_argument('--objective', default='H', type=str, help='beta-vae objective proposed in Higgins et al. or Burgess et al. H/B')
+    parser.add_argument('--model', default='H', type=str, help='model proposed in Higgins et al. or Burgess et al. H/B')
+    parser.add_argument('--gamma', default=1000, type=float, help='gamma parameter for KL-term in understanding beta-VAE')
+    parser.add_argument('--C_max', default=25, type=float, help='capacity parameter(C) of bottleneck channel')
+    parser.add_argument('--C_stop_iter', default=1e5, type=float, help='when to stop increasing the capacity')
     parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
-    parser.add_argument('--beta1', default=0.5, type=float, help='Adam optimizer beta1')
-    parser.add_argument('--beta2', default=0.9, type=float, help='Adam optimizer beta2')
+    parser.add_argument('--beta1', default=0.9, type=float, help='Adam optimizer beta1')
+    parser.add_argument('--beta2', default=0.999, type=float, help='Adam optimizer beta2')
 
     parser.add_argument('--dset_dir', default='data', type=str, help='dataset directory')
     parser.add_argument('--dataset', default='CelebA', type=str, help='dataset name')
@@ -53,7 +56,7 @@ if __name__ == "__main__":
     parser.add_argument('--viz_port', default=8097, type=str, help='visdom port number')
 
     parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory')
-    parser.add_argument('--load_ckpt', default=True, type=str2bool, help='load last checkpoint')
+    parser.add_argument('--ckpt_name', default=None, type=str, help='load previous checkpoint. insert checkpoint filename')
 
     args = parser.parse_args()
 
diff --git a/model.py b/model.py
index fd973ea..61f4791 100644
--- a/model.py
+++ b/model.py
@@ -7,104 +7,141 @@ import torch.nn.init as init
 from torch.autograd import Variable
 
 
-class BetaVAE_3D(nn.Module):
-    def __init__(self, z_dim=10):
-        super(BetaVAE_3D, self).__init__()
+def reparametrize(mu, logvar):
+    std = logvar.div(2).exp()
+    eps = Variable(std.data.new(std.size()).normal_())
+    return mu + std*eps
+
+
+class View(nn.Module):
+    def __init__(self, size):
+        super(View, self).__init__()
+        self.size = size
+
+    def forward(self, tensor):
+        return tensor.view(self.size)
+
+
+class BetaVAE_H(nn.Module):
+    """Model proposed in original beta-VAE paper(Higgins et al, ICLR, 2017)."""
+
+    def __init__(self, z_dim=10, nc=3):
+        super(BetaVAE_H, self).__init__()
         self.z_dim = z_dim
-        self.encode = nn.Sequential(
-            nn.Conv2d(3, 32, 4, 2, 1),
+        self.nc = nc
+        self.encoder = nn.Sequential(
+            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
             nn.ReLU(True),
-            nn.Conv2d(32, 32, 4, 2, 1),
+            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
             nn.ReLU(True),
-            nn.Conv2d(32, 64, 4, 2, 1),
+            nn.Conv2d(32, 64, 4, 2, 1),          # B,  64,  8,  8
             nn.ReLU(True),
-            nn.Conv2d(64, 64, 4, 2, 1),
+            nn.Conv2d(64, 64, 4, 2, 1),          # B,  64,  4,  4
             nn.ReLU(True),
-            nn.Conv2d(64, 256, 4, 1),
+            nn.Conv2d(64, 256, 4, 1),            # B, 256,  1,  1
             nn.ReLU(True),
-            nn.Conv2d(256, 2*self.z_dim, 1)
+            View((-1, 256*1*1)),                 # B, 256
+            nn.Linear(256, z_dim*2),             # B, z_dim*2
         )
-        self.decode = nn.Sequential(
-            nn.Conv2d(self.z_dim, 256, 1),
+        self.decoder = nn.Sequential(
+            nn.Linear(z_dim, 256),               # B, 256
+            View((-1, 256, 1, 1)),               # B, 256,  1,  1
             nn.ReLU(True),
-            nn.ConvTranspose2d(256, 64, 4),
+            nn.ConvTranspose2d(256, 64, 4),      # B,  64,  4,  4
             nn.ReLU(True),
-            nn.ConvTranspose2d(64, 64, 4, 2, 1),
+            nn.ConvTranspose2d(64, 64, 4, 2, 1), # B,  64,  8,  8
             nn.ReLU(True),
-            nn.ConvTranspose2d(64, 32, 4, 2, 1),
+            nn.ConvTranspose2d(64, 32, 4, 2, 1), # B,  32, 16, 16
             nn.ReLU(True),
-            nn.ConvTranspose2d(32, 32, 4, 2, 1),
+            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
             nn.ReLU(True),
-            nn.ConvTranspose2d(32, 3, 4, 2, 1),
+            nn.ConvTranspose2d(32, nc, 4, 2, 1),  # B, nc, 64, 64
         )
-        self.weight_init()
 
     def weight_init(self):
         for block in self._modules:
             for m in self._modules[block]:
                 kaiming_init(m)
 
-    def reparametrize(self, mu, logvar):
-        std = logvar.mul(0.5).exp_()
-        eps = Variable(std.data.new(std.size()).normal_())
-        return eps.mul(std).add_(mu)
-
     def forward(self, x):
-        stats = self.encode(x)
-        mu = stats[:, :self.z_dim]
-        logvar = stats[:, self.z_dim:]
-        z = self.reparametrize(mu, logvar)
-        x_recon = self.decode(z)
+        distributions = self._encode(x)
+        mu = distributions[:, :self.z_dim]
+        logvar = distributions[:, self.z_dim:]
+        z = reparametrize(mu, logvar)
+        x_recon = self._decode(z)
 
-        return x_recon, mu.squeeze(), logvar.squeeze()
+        return x_recon, mu, logvar
 
+    def _encode(self, x):
+        return self.encoder(x)
 
-class View(nn.Module):
-    def __init__(self, size):
-        super(View, self).__init__()
-        self.size = size
+    def _decode(self, z):
+        return self.decoder(z)
 
-    def forward(self, tensor):
-        return tensor.view(self.size)
 
+class BetaVAE_B(BetaVAE_H):
+    """Model proposed in understanding beta-VAE paper(Burgess et al, arxiv:1804.03599, 2018)."""
 
-class BetaVAE_2D(BetaVAE_3D):
-    def __init__(self, z_dim=10):
-        super(BetaVAE_2D, self).__init__()
+    def __init__(self, z_dim=10, nc=1):
+        super(BetaVAE_B, self).__init__()
+        self.nc = nc
         self.z_dim = z_dim
 
-        # Views are applied just for the consistency in shape with CONV-based models
-        self.encode = nn.Sequential(
-            View((-1, 4096)),
-            nn.Linear(4096, 1200),
+        self.encoder = nn.Sequential(
+            nn.Conv2d(nc, 32, 4, 2, 1),          # B,  32, 32, 32
+            nn.ReLU(True),
+            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32, 16, 16
             nn.ReLU(True),
-            nn.Linear(1200, 1200),
+            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32,  8,  8
             nn.ReLU(True),
-            nn.Linear(1200, 2*self.z_dim),
-            View((-1, 2*self.z_dim, 1, 1)),
+            nn.Conv2d(32, 32, 4, 2, 1),          # B,  32,  4,  4
+            nn.ReLU(True),
+            View((-1, 32*4*4)),                  # B, 512
+            nn.Linear(32*4*4, 256),              # B, 256
+            nn.ReLU(True),
+            nn.Linear(256, 256),                 # B, 256
+            nn.ReLU(True),
+            nn.Linear(256, z_dim*2),             # B, z_dim*2
         )
-        self.decode = nn.Sequential(
-            View((-1, self.z_dim)),
-            nn.Linear(self.z_dim, 1200),
-            nn.Tanh(),
-            nn.Linear(1200, 1200),
-            nn.Tanh(),
-            nn.Linear(1200, 1200),
-            nn.Tanh(),
-            nn.Linear(1200, 4096),
-            View((-1, 1, 64, 64)),
+
+        self.decoder = nn.Sequential(
+            nn.Linear(z_dim, 256),               # B, 256
+            nn.ReLU(True),
+            nn.Linear(256, 256),                 # B, 256
+            nn.ReLU(True),
+            nn.Linear(256, 32*4*4),              # B, 512
+            nn.ReLU(True),
+            View((-1, 32, 4, 4)),                # B,  32,  4,  4
+            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32,  8,  8
+            nn.ReLU(True),
+            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 16, 16
+            nn.ReLU(True),
+            nn.ConvTranspose2d(32, 32, 4, 2, 1), # B,  32, 32, 32
+            nn.ReLU(True),
+            nn.ConvTranspose2d(32, nc, 4, 2, 1), # B,  nc, 64, 64
         )
         self.weight_init()
 
+    def weight_init(self):
+        for block in self._modules:
+            for m in self._modules[block]:
+                kaiming_init(m)
+
     def forward(self, x):
-        stats = self.encode(x)
-        mu = stats[:, :self.z_dim]
-        logvar = stats[:, self.z_dim:]
-        z = self.reparametrize(mu, logvar)
-        x_recon = self.decode(z).view(x.size())
+        distributions = self._encode(x)
+        mu = distributions[:, :self.z_dim]
+        logvar = distributions[:, self.z_dim:]
+        z = reparametrize(mu, logvar)
+        x_recon = self._decode(z).view(x.size())
 
         return x_recon, mu, logvar
 
+    def _encode(self, x):
+        return self.encoder(x)
+
+    def _decode(self, z):
+        return self.decoder(z)
+
 
 def kaiming_init(m):
     if isinstance(m, (nn.Linear, nn.Conv2d)):
@@ -117,8 +154,16 @@ def kaiming_init(m):
             m.bias.data.fill_(0)
 
 
+def normal_init(m, mean, std):
+    if isinstance(m, (nn.Linear, nn.Conv2d)):
+        m.weight.data.normal_(mean, std)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
+        m.weight.data.fill_(1)
+        if m.bias is not None:
+            m.bias.data.zero_()
+
+
 if __name__ == '__main__':
-    import ipdb; ipdb.set_trace()
-    net = BetaVAE(32)
-    x = Variable(torch.rand(1, 3, 64, 64))
-    net(x)
+    pass
diff --git a/run_3dchairs.sh b/run_3dchairs.sh
deleted file mode 100644
index bf3e82d..0000000
--- a/run_3dchairs.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#! /bin/sh
-
-python main.py --dataset 3dchairs --max_iter 1e6 --beta 4 --batch_size 64 --lr 1e-4 --z_dim 10 --viz_name beta_vae_3dchairs --viz_port 55558
diff --git a/run_3dchairs_H.sh b/run_3dchairs_H.sh
new file mode 100644
index 0000000..81ed6e8
--- /dev/null
+++ b/run_3dchairs_H.sh
@@ -0,0 +1,5 @@
+#! /bin/sh
+
+python main.py --dataset 3dchairs --seed 1 --lr 1e-4 --beta1 0.9 --beta2 0.999 \
+    --objective H --model H --batch_size 64 --z_dim 10 --max_iter 1e6 \
+    --beta 4 --viz_name 3dchairs_H
diff --git a/run_celeba.sh b/run_celeba.sh
deleted file mode 100644
index 8a221e8..0000000
--- a/run_celeba.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#! /bin/sh
-
-python main.py --dataset celeba --max_iter 1e6 --beta 64 --batch_size 64 --lr 1e-4 --z_dim 10 --viz_name beta_vae_celeba --viz_port 55558
diff --git a/run_celeba_H.sh b/run_celeba_H.sh
new file mode 100644
index 0000000..1c2dee8
--- /dev/null
+++ b/run_celeba_H.sh
@@ -0,0 +1,5 @@
+#! /bin/sh
+
+python main.py --dataset celeba --seed 1 --lr 1e-4 --beta1 0.9 --beta2 0.999 \
+    --objective H --model H --batch_size 64 --z_dim 10 --max_iter 1e6 \
+    --beta 64 --viz_name celeba_H
diff --git a/run_dsprites.sh b/run_dsprites.sh
deleted file mode 100644
index 0658184..0000000
--- a/run_dsprites.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#! /bin/sh
-
-python main.py --dataset dsprites --max_iter 3e5 --beta 4 --batch_size 64 --lr 1e-4 --z_dim 10 --viz_name beta_vae_dsprites --viz_port 55558
diff --git a/run_dsprites_B.sh b/run_dsprites_B.sh
new file mode 100644
index 0000000..346f5c2
--- /dev/null
+++ b/run_dsprites_B.sh
@@ -0,0 +1,5 @@
+#! /bin/sh
+
+python main.py --dataset dsprites --seed 2 --lr 5e-4 --beta1 0.9 --beta2 0.999 \
+    --objective B --model B --batch_size 64 --z_dim 10 --max_iter 1e6 \
+    --C_stop_iter 1e5 --C_max 20 --gamma 100 --viz_name dsprites_B
diff --git a/run_dsprites_B2.sh b/run_dsprites_B2.sh
new file mode 100644
index 0000000..f373406
--- /dev/null
+++ b/run_dsprites_B2.sh
@@ -0,0 +1,5 @@
+#! /bin/sh
+
+python main.py --dataset dsprites --seed 2 --lr 5e-4 --beta1 0.9 --beta2 0.999 \
+    --objective B --model B --batch_size 64 --z_dim 10 --max_iter 1e6 \
+    --C_stop_iter 1e5 --C_max 20 --gamma 100 --viz_name dsprites_B --ckpt_name last --viz_port 55558
diff --git a/solver.py b/solver.py
index 15c6dfd..27a8779 100644
--- a/solver.py
+++ b/solver.py
@@ -1,9 +1,9 @@
 """solver.py"""
 
-import time
 from pathlib import Path
-
+from tqdm import tqdm
 import visdom
+
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
@@ -11,194 +11,364 @@ from torch.autograd import Variable
 from torchvision.utils import make_grid
 
 from utils import cuda
-from model import BetaVAE_2D, BetaVAE_3D
+from model import BetaVAE_H, BetaVAE_B
 from dataset import return_data
 
 
-def original_vae_loss(x, x_recon, mu, logvar):
+def reconstruction_loss(x, x_recon, distribution):
     batch_size = x.size(0)
-    if batch_size == 0:
-        recon_loss = 0
-        kl_divergence = 0
-    else:
+    assert batch_size != 0
+
+    if distribution == 'bernoulli':
         recon_loss = F.binary_cross_entropy_with_logits(x_recon, x, size_average=False).div(batch_size)
-        # kld which one is correct? from the equation in most of papers,
-        # I think the first one is correct but official pytorch code uses the second one.
+    elif distribution == 'gaussian':
+        x_recon = F.sigmoid(x_recon)
+        recon_loss = F.mse_loss(x_recon, x, size_average=False).div(batch_size)
+    else:
+        recon_loss = None
+
+    return recon_loss
+
+
+def kl_divergence(mu, logvar):
+    batch_size = mu.size(0)
+    assert batch_size != 0
+    if mu.data.ndimension() == 4:
+        mu = mu.view(mu.size(0), mu.size(1))
+    if logvar.data.ndimension() == 4:
+        logvar = logvar.view(logvar.size(0), logvar.size(1))
+
+    klds = -0.5*(1 + logvar - mu.pow(2) - logvar.exp())
+    total_kld = klds.sum(1).mean()
+    dimension_wise_kld = klds.mean(0)
+    mean_kld = klds.mean()
+
+    return total_kld, dimension_wise_kld, mean_kld
+
+
+class DataGather(object):
+    def __init__(self):
+        self.data = self.get_empty_data_dict()
+
+    def get_empty_data_dict(self):
+        return dict(iter=[],
+                    recon_loss=[],
+                    total_kld=[],
+                    dim_wise_kld=[],
+                    mean_kld=[],
+                    mu=[],
+                    var=[],
+                    images=[],)
 
-        kl_divergence = -0.5*(1 + logvar - mu**2 - logvar.exp()).sum(1).mean()
-        #kl_divergence = -0.5*(1 + logvar - mu**2 - logvar.exp()).sum()
-        #dimension_wise_kl_divergence = -0.5*(1 + logvar - mu**2 - logvar.exp()).mean(0)
+    def insert(self, **kwargs):
+        for key in kwargs:
+            self.data[key].append(kwargs[key])
 
-    return recon_loss, kl_divergence
+    def flush(self):
+        self.data = self.get_empty_data_dict()
 
 
 class Solver(object):
     def __init__(self, args):
-
-        # Misc
         self.use_cuda = args.cuda and torch.cuda.is_available()
         self.max_iter = args.max_iter
         self.global_iter = 0
 
-        # Networks & Optimizers
         self.z_dim = args.z_dim
         self.beta = args.beta
-
+        self.gamma = args.gamma
+        self.C_max = args.C_max
+        self.C_stop_iter = args.C_stop_iter
+        self.objective = args.objective
+        self.model = args.model
         self.lr = args.lr
         self.beta1 = args.beta1
         self.beta2 = args.beta2
 
         if args.dataset.lower() == 'dsprites':
-            net = BetaVAE_2D
+            self.nc = 1
+            self.decoder_dist = 'bernoulli'
         elif args.dataset.lower() == '3dchairs':
-            net = BetaVAE_3D
+            self.nc = 3
+            self.decoder_dist = 'bernoulli'
         elif args.dataset.lower() == 'celeba':
-            net = BetaVAE_3D
+            self.nc = 3
+            self.decoder_dist = 'gaussian'
         else:
             raise NotImplementedError
 
-        self.net = cuda(net(self.z_dim), self.use_cuda)
+        if args.model == 'H':
+            net = BetaVAE_H
+        elif args.model == 'B':
+            net = BetaVAE_B
+        else:
+            raise NotImplementedError('only support model H or B')
+
+        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
         self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                     betas=(self.beta1, self.beta2))
 
-        # Visdom
         self.viz_name = args.viz_name
         self.viz_port = args.viz_port
         self.viz_on = args.viz_on
         if self.viz_on:
-            self.viz = visdom.Visdom(env=self.viz_name, port=self.viz_port)
-            self.viz_curves = visdom.Visdom(env=self.viz_name+'/train_curves', port=self.viz_port)
+            self.viz = visdom.Visdom(env=self.viz_name+'_lines', port=self.viz_port)
             self.win_recon = None
             self.win_kld = None
+            self.win_mu = None
+            self.win_var = None
 
-        # Checkpoint
         self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
         if not self.ckpt_dir.exists():
             self.ckpt_dir.mkdir(parents=True, exist_ok=True)
+        self.ckpt_name = args.ckpt_name
+        if self.ckpt_name is not None:
+            self.load_checkpoint(self.ckpt_name)
 
-        self.load_ckpt = args.load_ckpt
-        if self.load_ckpt:
-            self.load_checkpoint()
-
-        # Data
         self.dset_dir = args.dset_dir
+        self.dataset = args.dataset
         self.batch_size = args.batch_size
         self.data_loader = return_data(args)
 
+        self.gather = DataGather()
+
     def train(self):
         self.net_mode(train=True)
+        self.C_max = Variable(cuda(torch.FloatTensor([self.C_max]), self.use_cuda))
         out = False
+
+        pbar = tqdm(total=self.max_iter)
+        pbar.update(self.global_iter)
         while not out:
-            start = time.time()
             curve_data = []
+            curves = dict(iter=[], total_kld=[], dim_wise_kld=[], mean_kld=[])
             for x in self.data_loader:
                 self.global_iter += 1
+                pbar.update(1)
 
                 x = Variable(cuda(x, self.use_cuda))
                 x_recon, mu, logvar = self.net(x)
-                recon_loss, kld = original_vae_loss(x, x_recon, mu, logvar)
+                recon_loss = reconstruction_loss(x, x_recon, self.decoder_dist)
+                total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
 
-                beta_vae_loss = recon_loss + self.beta*kld
+                if self.objective == 'H':
+                    beta_vae_loss = recon_loss + self.beta*total_kld
+                elif self.objective == 'B':
+                    C = torch.clamp(self.C_max/self.C_stop_iter*self.global_iter, 0, self.C_max.data[0])
+                    beta_vae_loss = recon_loss + self.gamma*(total_kld-C).abs()
 
                 self.optim.zero_grad()
                 beta_vae_loss.backward()
                 self.optim.step()
 
                 if self.global_iter%1000 == 0:
-                    curve_data.append(torch.Tensor([self.global_iter,
-                                                    recon_loss.data[0],
-                                                    kld.data[0],]))
+                    self.gather.insert(iter=self.global_iter,
+                                       mu=mu.mean(0).data, var=logvar.exp().mean(0).data,
+                                       recon_loss=recon_loss.data, total_kld=total_kld.data,
+                                       dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data)
 
                 if self.global_iter%5000 == 0:
-                    self.save_checkpoint()
-                    self.visualize(dict(image=[x, x_recon], curve=curve_data))
-                    print('[{}] recon_loss:{:.3f} beta*kld:{:.3f}'.format(
-                        self.global_iter, recon_loss.data[0], self.beta*kld.data[0]))
-                    curve_data = []
-
-                if self.global_iter%100000 == 0:
+                    self.gather.insert(images=x.data)
+                    self.gather.insert(images=x_recon.data)
+                    self.visualize()
+                    self.gather.flush()
+                    self.save_checkpoint('last')
+                    pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
+                        self.global_iter, recon_loss.data[0], total_kld.data[0], mean_kld.data[0]))
+
+                    var = logvar.exp().mean(0).data
+                    var_str = ''
+                    for j, var_j in enumerate(var):
+                        var_str += 'var{}:{:.4f} '.format(j+1, var_j)
+                    pbar.write(var_str)
+
+                    if self.objective == 'B':
+                        pbar.write('C:{:.3f}'.format(C.data[0]))
+
+                if self.global_iter%10000 == 0:
                     self.traverse()
 
+                if self.global_iter%50000 == 0:
+                    self.save_checkpoint(str(self.global_iter))
+
                 if self.global_iter >= self.max_iter:
                     out = True
                     break
 
-            end = time.time()
-            print('[time elapsed] {:.2f}s/epoch'.format(end-start))
-        print("[Training Finished]")
+        pbar.write("[Training Finished]")
+        pbar.close()
 
-    def visualize(self, data):
-        x, x_recon = data['image']
-        curve_data = data['curve']
+    def visualize(self):
+        self.net_mode(train=False)
+        x = self.gather.data['images'][0][:100]
+        x = make_grid(x, normalize=False)
+        x_recon = F.sigmoid(self.gather.data['images'][1])[:100]
+        x_recon = make_grid(x_recon, normalize=False)
+        images = torch.stack([x, x_recon], dim=0).cpu()
+        self.viz.images(images, env=self.viz_name+'_reconstruction',
+                        opts=dict(title=str(self.global_iter)), nrow=10)
 
-        sample_x = make_grid(x.data.cpu(), normalize=False)
-        sample_x_recon = make_grid(F.sigmoid(x_recon).data.cpu(), normalize=False)
-        samples = torch.stack([sample_x, sample_x_recon], dim=0)
-        self.viz.images(samples, opts=dict(title=str(self.global_iter)))
+        recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
 
-        curve_data = torch.stack(curve_data, dim=0)
-        curve_iter = curve_data[:, 0]
-        curve_recon = curve_data[:, 1]
-        curve_kld = curve_data[:, 2]
+        mus = torch.stack(self.gather.data['mu']).cpu()
+        vars = torch.stack(self.gather.data['var']).cpu()
+
+        dim_wise_klds = torch.stack(self.gather.data['dim_wise_kld'])
+        mean_klds = torch.stack(self.gather.data['mean_kld'])
+        total_klds = torch.stack(self.gather.data['total_kld'])
+        klds = torch.cat([dim_wise_klds, mean_klds, total_klds], 1).cpu()
+        legend = []
+        for z_j in range(self.z_dim):
+            legend.append('z_{}'.format(z_j))
+        legend.append('mean')
+        legend.append('total')
+
+        iters = torch.Tensor(self.gather.data['iter'])
 
         if self.win_recon is None:
-            self.win_recon = self.viz_curves.line(
-                                        X=curve_iter,
-                                        Y=curve_recon,
+            self.win_recon = self.viz.line(
+                                        X=iters,
+                                        Y=recon_losses,
+                                        env=self.viz_name+'_lines',
                                         opts=dict(
+                                            width=400,
+                                            height=400,
                                             xlabel='iteration',
-                                            ylabel='reconsturction loss',))
+                                            title='reconstruction loss',))
         else:
-            self.win_recon = self.viz_curves.line(
-                                        X=curve_iter,
-                                        Y=curve_recon,
+            self.win_recon = self.viz.line(
+                                        X=iters,
+                                        Y=recon_losses,
+                                        env=self.viz_name+'_lines',
                                         win=self.win_recon,
                                         update='append',
                                         opts=dict(
+                                            width=400,
+                                            height=400,
                                             xlabel='iteration',
-                                            ylabel='reconsturction loss',))
+                                            title='reconstruction loss',))
 
         if self.win_kld is None:
-            self.win_kld = self.viz_curves.line(
-                                        X=curve_iter,
-                                        Y=curve_kld,
+            self.win_kld = self.viz.line(
+                                        X=iters,
+                                        Y=klds,
+                                        env=self.viz_name+'_lines',
                                         opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend,
                                             xlabel='iteration',
-                                            ylabel='kl divergence',))
+                                            title='kl divergence',))
         else:
-            self.win_kld = self.viz_curves.line(
-                                        X=curve_iter,
-                                        Y=curve_kld,
+            self.win_kld = self.viz.line(
+                                        X=iters,
+                                        Y=klds,
+                                        env=self.viz_name+'_lines',
                                         win=self.win_kld,
                                         update='append',
                                         opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend,
                                             xlabel='iteration',
-                                            ylabel='kl divergence',))
+                                            title='kl divergence',))
 
-    def traverse(self):
+        if self.win_mu is None:
+            self.win_mu = self.viz.line(
+                                        X=iters,
+                                        Y=mus,
+                                        env=self.viz_name+'_lines',
+                                        opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend[:self.z_dim],
+                                            xlabel='iteration',
+                                            title='posterior mean',))
+        else:
+            self.win_mu = self.viz.line(
+                                        X=iters,
+                                        Y=mus,
+                                        env=self.viz_name+'_lines',
+                                        win=self.win_mu,
+                                        update='append',
+                                        opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend[:self.z_dim],
+                                            xlabel='iteration',
+                                            title='posterior mean',))
+
+        if self.win_var is None:
+            self.win_var = self.viz.line(
+                                        X=iters,
+                                        Y=vars,
+                                        env=self.viz_name+'_lines',
+                                        opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend[:self.z_dim],
+                                            xlabel='iteration',
+                                            title='posterior variance',))
+        else:
+            self.win_var = self.viz.line(
+                                        X=iters,
+                                        Y=vars,
+                                        env=self.viz_name+'_lines',
+                                        win=self.win_var,
+                                        update='append',
+                                        opts=dict(
+                                            width=400,
+                                            height=400,
+                                            legend=legend[:self.z_dim],
+                                            xlabel='iteration',
+                                            title='posterior variance',))
+        self.net_mode(train=True)
+
+    def traverse(self, limit=3, inter=2/3):
+        self.net_mode(train=False)
         import random
 
-        decoder = self.net.decode
-        encoder = self.net.encode
-        interpolation = torch.arange(-6, 6.1, 1)
+        decoder = self.net.decoder
+        encoder = self.net.encoder
+        interpolation = torch.arange(-limit, limit+0.1, inter)
         viz = visdom.Visdom(env=self.viz_name+'/traverse', port=self.viz_port)
 
-        n_dsets = self.data_loader.dataset.__len__()
-        fixed_idx = 0
+        n_dsets = len(self.data_loader.dataset)
         rand_idx = random.randint(1, n_dsets-1)
 
-        fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
-        fixed_img = Variable(cuda(fixed_img, self.use_cuda), volatile=True).unsqueeze(0)
-        fixed_img_z = encoder(fixed_img)[:, :self.z_dim]
-
         random_img = self.data_loader.dataset.__getitem__(rand_idx)
         random_img = Variable(cuda(random_img, self.use_cuda), volatile=True).unsqueeze(0)
         random_img_z = encoder(random_img)[:, :self.z_dim]
 
-        zero_z = Variable(cuda(torch.zeros(1, self.z_dim, 1, 1), self.use_cuda), volatile=True)
-        random_z = Variable(cuda(torch.rand(1, self.z_dim, 1, 1), self.use_cuda), volatile=True)
+        random_z = Variable(cuda(torch.rand(1, self.z_dim), self.use_cuda), volatile=True)
+
+        if self.dataset == 'dsprites':
+            fixed_idx1 = 87040 # square
+            fixed_idx2 = 332800 # ellipse
+            fixed_idx3 = 578560 # heart
+
+            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)
+            fixed_img1 = Variable(cuda(fixed_img1, self.use_cuda), volatile=True).unsqueeze(0)
+            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
+
+            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)
+            fixed_img2 = Variable(cuda(fixed_img2, self.use_cuda), volatile=True).unsqueeze(0)
+            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
+
+            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)
+            fixed_img3 = Variable(cuda(fixed_img3, self.use_cuda), volatile=True).unsqueeze(0)
+            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
+
+            Z = {'fixed_square':fixed_img_z1, 'fixed_ellipse':fixed_img_z2,
+                 'fixed_heart':fixed_img_z3, 'random_img':random_img_z}
+        else:
+            fixed_idx = 0
+            fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
+            fixed_img = Variable(cuda(fixed_img, self.use_cuda), volatile=True).unsqueeze(0)
+            fixed_img_z = encoder(fixed_img)[:, :self.z_dim]
+
+            Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z}
 
-        Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z, 'zero_z':zero_z}
         for key in Z.keys():
             z_ori = Z[key]
             samples = []
@@ -209,9 +379,11 @@ class Solver(object):
                     sample = F.sigmoid(decoder(z))
                     samples.append(sample)
             samples = torch.cat(samples, dim=0).data.cpu()
-            title = '{}_row_traverse(iter:{})'.format(key, self.global_iter)
+            title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
             viz.images(samples, opts=dict(title=title), nrow=len(interpolation))
 
+        self.net_mode(train=True)
+
     def net_mode(self, train):
         if not isinstance(train, bool):
             raise('Only bool type is supported. True or False')
@@ -221,11 +393,13 @@ class Solver(object):
         else:
             self.net.eval()
 
-    def save_checkpoint(self, filename='ckpt.tar', silent=True):
+    def save_checkpoint(self, filename, silent=True):
         model_states = {'net':self.net.state_dict(),}
         optim_states = {'optim':self.optim.state_dict(),}
         win_states = {'recon':self.win_recon,
-                      'kld':self.win_kld,}
+                      'kld':self.win_kld,
+                      'mu':self.win_mu,
+                      'var':self.win_var,}
         states = {'iter':self.global_iter,
                   'win_states':win_states,
                   'model_states':model_states,
@@ -236,13 +410,15 @@ class Solver(object):
         if not silent:
             print("=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter))
 
-    def load_checkpoint(self, filename='ckpt.tar'):
+    def load_checkpoint(self, filename):
         file_path = self.ckpt_dir.joinpath(filename)
         if file_path.is_file():
             checkpoint = torch.load(file_path.open('rb'))
             self.global_iter = checkpoint['iter']
             self.win_recon = checkpoint['win_states']['recon']
             self.win_kld = checkpoint['win_states']['kld']
+            self.win_var = checkpoint['win_states']['var']
+            self.win_mu = checkpoint['win_states']['mu']
             self.net.load_state_dict(checkpoint['model_states']['net'])
             self.optim.load_state_dict(checkpoint['optim_states']['optim'])
             print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))
-- 
GitLab