# solver.py
"""solver.py"""

import random
from pathlib import Path

import torch
import torch.nn.functional as F
import torch.optim as optim
import visdom
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

from dataset import return_data
from model import BetaVAE_H, BetaVAE_B
from utils import cuda, grid2gif


def reconstruction_loss(x, x_recon, distribution):
    """Return the reconstruction loss of a batch, averaged over samples.

    The loss is summed over every element of the batch and then divided by
    the batch size (``x.size(0)``).

    Args:
        x: target batch; its first dimension is the batch size.
        x_recon: raw decoder output (no sigmoid applied yet), same shape
            as ``x``.
        distribution: ``'bernoulli'`` selects binary cross-entropy on the
            logits; ``'gaussian'`` selects squared error on
            ``sigmoid(x_recon)``.

    Returns:
        The loss tensor, or ``None`` when ``distribution`` is neither of
        the two supported names.

    Raises:
        AssertionError: if the batch is empty.
    """
    batch_size = x.size(0)
    assert batch_size != 0

    # ``reduction='sum'`` is the supported spelling of the deprecated
    # ``size_average=False``; ``torch.sigmoid`` replaces ``F.sigmoid``.
    if distribution == 'bernoulli':
        summed = F.binary_cross_entropy_with_logits(x_recon, x, reduction='sum')
        return summed.div(batch_size)

    if distribution == 'gaussian':
        summed = F.mse_loss(torch.sigmoid(x_recon), x, reduction='sum')
        return summed.div(batch_size)

    return None


def kl_divergence(mu, logvar):
    """Compute KL-divergence summaries from ``mu`` and ``logvar``.

    The element-wise term is
    ``-0.5 * (1 + logvar - mu**2 - exp(logvar))``.

    Args:
        mu: tensor whose first dimension is the batch. A 4-D input is
            reshaped to ``(size(0), size(1))``, so its trailing two
            dimensions must each have size 1.
        logvar: tensor handled the same way as ``mu``.

    Returns:
        A tuple ``(total_kld, dimension_wise_kld, mean_kld)``:
        ``total_kld`` sums over dimension 1 and averages over the batch
        (kept as a 1-element tensor), ``dimension_wise_kld`` averages each
        column over the batch, and ``mean_kld`` averages over both
        (kept as a 1-element tensor).

    Raises:
        AssertionError: if the batch is empty.
    """
    batch_size = mu.size(0)
    assert batch_size != 0

    # Collapse (N, D, 1, 1) inputs to (N, D).
    if mu.dim() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.dim() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5*(1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)

    return total_kld, dimension_wise_kld, mean_kld


class DataGather(object):
    """Accumulate named series of values between flushes.

    ``self.data`` maps each known key to a list; ``insert`` appends to
    those lists and ``flush`` replaces them with empty ones.
    """

    def __init__(self):
        self.data = self.get_empty_data_dict()

    def get_empty_data_dict(self):
        """Return a fresh dict with an empty list for every known key."""
        return dict(iter=[],
                    recon_loss=[],
                    total_kld=[],
                    dim_wise_kld=[],
                    mean_kld=[],
                    mu=[],
                    var=[],
                    images=[],)

    def insert(self, **kwargs):
        """Append each keyword value to the list stored under its name.

        Raises:
            KeyError: if a keyword is not one of the known keys.
        """
        for key, value in kwargs.items():
            self.data[key].append(value)

    def flush(self):
        """Discard everything gathered so far."""
        self.data = self.get_empty_data_dict()


class Solver(object):
    """Build, train, visualise and checkpoint a beta-VAE.

    Written against the PyTorch >= 0.4.1 API (``Tensor.item()``,
    ``torch.no_grad()``, ``reduction=`` on losses).
    """

    # Iteration intervals used by ``train`` for latent traversals and for
    # the numbered checkpoint snapshots.
    TRAVERSE_STEP = 20000
    SNAPSHOT_STEP = 50000

    def __init__(self, args):
        """Set up the network, optimiser, visualiser and data loader.

        Args:
            args: namespace providing ``cuda``, ``max_iter``, ``z_dim``,
                ``beta``, ``gamma``, ``C_max``, ``C_stop_iter``,
                ``objective``, ``model``, ``lr``, ``beta1``, ``beta2``,
                ``dataset``, ``viz_name``, ``viz_port``, ``viz_on``,
                ``ckpt_dir``, ``ckpt_name``, ``save_output``,
                ``output_dir``, ``gather_step``, ``display_step``,
                ``save_step``, ``dset_dir`` and ``batch_size``; it is also
                passed unchanged to ``return_data``.

        Raises:
            NotImplementedError: for an unsupported dataset or model name.
        """
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_iter = args.max_iter
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.beta = args.beta
        self.gamma = args.gamma
        self.C_max = args.C_max
        self.C_stop_iter = args.C_stop_iter
        self.objective = args.objective
        self.model = args.model
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        # Channel count and decoder likelihood depend on the dataset.
        dataset_name = args.dataset.lower()
        if dataset_name == 'dsprites':
            self.nc = 1
            self.decoder_dist = 'bernoulli'
        elif dataset_name in ('3dchairs', 'celeba'):
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError

        if args.model == 'H':
            net = BetaVAE_H
        elif args.model == 'B':
            net = BetaVAE_B
        else:
            raise NotImplementedError('only support model H or B')

        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        # Always define the handles so later code can test them safely
        # even when visualisation is switched off.
        self.viz = None
        self.win_recon = None
        self.win_kld = None
        self.win_mu = None
        self.win_var = None
        if self.viz_on:
            self.viz = visdom.Visdom(env=self.viz_name+'_lines', port=self.viz_port)

        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.gather_step = args.gather_step
        self.display_step = args.display_step
        self.save_step = args.save_step

        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        self.gather = DataGather()

    def train(self):
        """Run optimisation until ``global_iter`` reaches ``max_iter``.

        Objective ``'H'`` minimises ``recon_loss + beta * total_kld``;
        objective ``'B'`` minimises
        ``recon_loss + gamma * |total_kld - C|`` with ``C`` given by
        ``_capacity``. Statistics are gathered, displayed and checkpoints
        written at the configured step intervals.

        Raises:
            NotImplementedError: if ``objective`` is neither 'H' nor 'B'.
            RuntimeError: if the data loader yields no batches.
        """
        if self.objective not in ('H', 'B'):
            raise NotImplementedError('only support objective H or B')

        self.net_mode(train=True)

        pbar = tqdm(total=self.max_iter)
        pbar.update(self.global_iter)
        try:
            # A resumed run that already reached max_iter has nothing to do.
            out = self.global_iter >= self.max_iter
            while not out:
                saw_batch = False
                for x in self.data_loader:
                    saw_batch = True
                    self.global_iter += 1
                    pbar.update(1)

                    x = cuda(x, self.use_cuda)
                    x_recon, mu, logvar = self.net(x)
                    recon_loss = reconstruction_loss(x, x_recon, self.decoder_dist)
                    total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)

                    C = None
                    if self.objective == 'H':
                        beta_vae_loss = recon_loss + self.beta*total_kld
                    else:
                        C = self._capacity()
                        beta_vae_loss = recon_loss + self.gamma*(total_kld-C).abs()

                    self.optim.zero_grad()
                    beta_vae_loss.backward()
                    self.optim.step()

                    if self.viz_on and self.global_iter % self.gather_step == 0:
                        self.gather.insert(iter=self.global_iter,
                                           mu=mu.mean(0).data,
                                           var=logvar.exp().mean(0).data,
                                           recon_loss=recon_loss.data,
                                           total_kld=total_kld.data,
                                           dim_wise_kld=dim_wise_kld.data,
                                           mean_kld=mean_kld.data)

                    if self.global_iter % self.display_step == 0:
                        self._display(pbar, x, x_recon, logvar,
                                      recon_loss, total_kld, mean_kld, C)

                    if self.global_iter % self.save_step == 0:
                        self.save_checkpoint('last')
                        pbar.write('Saved checkpoint')

                    # A traversal only produces visdom images and/or files,
                    # so skip it when both outputs are disabled.
                    if (self.global_iter % self.TRAVERSE_STEP == 0
                            and (self.viz_on or self.save_output)):
                        self.viz_traverse()

                    if self.global_iter % self.SNAPSHOT_STEP == 0:
                        self.save_checkpoint(str(self.global_iter))

                    if self.global_iter >= self.max_iter:
                        out = True
                        break

                if not saw_batch:
                    # Without this the outer loop would spin forever.
                    raise RuntimeError('data loader produced no batches')

            pbar.write("[Training Finished]")
        finally:
            pbar.close()

    def _capacity(self):
        """Return C for objective 'B' at the current iteration.

        C grows linearly as ``C_max / C_stop_iter * global_iter`` and is
        clamped to ``[0, C_max]``. A non-positive ``C_stop_iter`` yields
        ``C_max`` directly.
        """
        c_max = float(self.C_max)
        if self.C_stop_iter <= 0:
            return c_max
        return min(max(c_max / self.C_stop_iter * self.global_iter, 0.0), c_max)

    def _display(self, pbar, x, x_recon, logvar, recon_loss, total_kld, mean_kld, C):
        """Push plots to visdom (if enabled) and write a progress summary."""
        if self.viz_on:
            self.gather.insert(images=x.data)
            self.gather.insert(images=torch.sigmoid(x_recon).data)
            self.viz_reconstruction()
            self.viz_lines()
        self.gather.flush()

        # ``.item()`` works for both 0-dim and 1-element tensors.
        pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
            self.global_iter, recon_loss.item(), total_kld.item(), mean_kld.item()))

        var = logvar.exp().mean(0).data.reshape(-1).tolist()
        pbar.write(''.join('var{}:{:.4f} '.format(j+1, var_j)
                           for j, var_j in enumerate(var)))

        if C is not None:
            pbar.write('C:{:.3f}'.format(C))

    def viz_reconstruction(self):
        """Send the gathered inputs and reconstructions to visdom.

        Expects ``gather.data['images']`` to hold the input batch at
        index 0 and the reconstructed batch at index 1; at most the first
        100 images of each are shown.
        """
        self.net_mode(train=False)
        gathered = self.gather.data['images']
        x = make_grid(gathered[0][:100], normalize=True)
        x_recon = make_grid(gathered[1][:100], normalize=True)
        images = torch.stack([x, x_recon], dim=0).cpu()
        self.viz.images(images, env=self.viz_name+'_reconstruction',
                        opts=dict(title=str(self.global_iter)), nrow=10)
        self.net_mode(train=True)

    def viz_lines(self):
        """Plot the gathered loss, KLD, mean and variance series.

        Each window is created on first use and appended to afterwards.
        Does nothing when no statistics were gathered since the last
        flush.
        """
        data = self.gather.data
        if not data['iter']:
            return

        self.net_mode(train=False)
        recon_losses = torch.stack(data['recon_loss']).cpu()
        mus = torch.stack(data['mu']).cpu()
        variances = torch.stack(data['var']).cpu()
        klds = torch.cat([torch.stack(data['dim_wise_kld']),
                          torch.stack(data['mean_kld']),
                          torch.stack(data['total_kld'])], 1).cpu()
        iters = torch.Tensor(data['iter'])

        legend = ['z_{}'.format(z_j) for z_j in range(self.z_dim)]
        latent_legend = list(legend)
        legend.append('mean')
        legend.append('total')

        self.win_recon = self._plot_line(self.win_recon, iters, recon_losses,
                                         'reconsturction loss')
        self.win_kld = self._plot_line(self.win_kld, iters, klds,
                                       'kl divergence', legend)
        # The append path for this window used to plot the variances.
        self.win_mu = self._plot_line(self.win_mu, iters, mus,
                                      'posterior mean', latent_legend)
        self.win_var = self._plot_line(self.win_var, iters, variances,
                                       'posterior variance', latent_legend)
        self.net_mode(train=True)

    def _plot_line(self, win, X, Y, title, legend=None):
        """Create (``win is None``) or append to a visdom line window.

        Returns whatever ``self.viz.line`` returns, to be stored as the
        window handle.
        """
        opts = dict(width=400, height=400, xlabel='iteration', title=title)
        if legend is not None:
            opts['legend'] = legend
        kwargs = dict(X=X, Y=Y, env=self.viz_name+'_lines', opts=opts)
        if win is not None:
            kwargs.update(win=win, update='append')
        return self.viz.line(**kwargs)

    def viz_traverse(self, limit=3, inter=2/3, loc=-1):
        """Decode latent traversals and show and/or save them.

        For every starting code, each selected latent dimension is swept
        over ``torch.arange(-limit, limit+0.1, inter)`` while the other
        dimensions stay fixed, and the decoder output is passed through a
        sigmoid.

        Args:
            limit: half-width of the swept range.
            inter: step between swept values.
            loc: index of the single dimension to traverse, or -1 for all
                ``z_dim`` dimensions.

        Raises:
            ValueError: if ``loc`` is neither -1 nor a valid dimension.
        """
        if loc == -1:
            rows = list(range(self.z_dim))
        elif 0 <= loc < self.z_dim:
            rows = [loc]
        else:
            raise ValueError('loc must be -1 or in [0, {})'.format(self.z_dim))

        self.net_mode(train=False)
        try:
            decoder = self.net.decoder
            values = torch.arange(-limit, limit+0.1, inter).tolist()

            n_dsets = len(self.data_loader.dataset)
            # Index 0 is left out of the random draw, as before; a
            # one-item dataset has no other choice.
            rand_idx = random.randint(1, n_dsets-1) if n_dsets > 1 else 0

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                random_img_z = self._encode_sample(rand_idx)
                random_z = cuda(torch.rand(1, self.z_dim), self.use_cuda)

                # Same case-insensitive match as in __init__.
                if self.dataset.lower() == 'dsprites':
                    Z = {'fixed_square': self._encode_sample(87040),
                         'fixed_ellipse': self._encode_sample(332800),
                         'fixed_heart': self._encode_sample(578560),
                         'random_img': random_img_z}
                else:
                    Z = {'fixed_img': self._encode_sample(0),
                         'random_img': random_img_z,
                         'random_z': random_z}

                gifs = []
                for key, z_ori in Z.items():
                    samples = []
                    for row in rows:
                        z = z_ori.clone()
                        for val in values:
                            z[:, row] = val
                            sample = torch.sigmoid(decoder(z)).data
                            samples.append(sample)
                            gifs.append(sample)
                    samples = torch.cat(samples, dim=0).cpu()
                    if self.viz_on:
                        title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
                        self.viz.images(samples, env=self.viz_name+'_traverse',
                                        opts=dict(title=title), nrow=len(values))

            if self.save_output:
                output_dir = self.output_dir.joinpath(str(self.global_iter))
                output_dir.mkdir(parents=True, exist_ok=True)
                gifs = torch.cat(gifs)
                # Use the traversed row count and the decoder's own image
                # shape rather than assuming z_dim rows of 64x64 images.
                gifs = gifs.view(len(Z), len(rows), len(values),
                                 *gifs.shape[1:]).transpose(1, 2)
                for i, key in enumerate(Z):
                    for j in range(len(values)):
                        frame_path = output_dir.joinpath('{}_{}.jpg'.format(key, j))
                        # The destination is passed positionally because its
                        # keyword name differs across torchvision versions.
                        save_image(gifs[i][j].cpu(), str(frame_path),
                                   nrow=len(rows), pad_value=1)

                    grid2gif(str(output_dir.joinpath(key+'*.jpg')),
                             str(output_dir.joinpath(key+'.gif')), delay=10)
        finally:
            # Restore training mode even if decoding or saving failed.
            self.net_mode(train=True)

    def _encode_sample(self, index):
        """Encode dataset item ``index`` and keep the first ``z_dim`` outputs."""
        img = self.data_loader.dataset[index]
        img = cuda(img, self.use_cuda).unsqueeze(0)
        return self.net.encoder(img)[:, :self.z_dim]

    def net_mode(self, train):
        """Put the network in training (True) or evaluation (False) mode.

        Raises:
            TypeError: if ``train`` is not a bool.
        """
        if not isinstance(train, bool):
            raise TypeError('Only bool type is supported. True or False')

        if train:
            self.net.train()
        else:
            self.net.eval()

    def save_checkpoint(self, filename, silent=True):
        """Write iteration, window handles, model and optimiser state.

        Args:
            filename: file name inside ``ckpt_dir``.
            silent: when False, print a confirmation line.
        """
        model_states = {'net': self.net.state_dict()}
        optim_states = {'optim': self.optim.state_dict()}
        if self.viz_on:
            win_states = {'recon': self.win_recon,
                          'kld': self.win_kld,
                          'mu': self.win_mu,
                          'var': self.win_var}
        else:
            win_states = {'recon': None,
                          'kld': None,
                          'mu': None,
                          'var': None}
        states = {'iter': self.global_iter,
                  'win_states': win_states,
                  'model_states': model_states,
                  'optim_states': optim_states}

        file_path = self.ckpt_dir.joinpath(filename)
        # Close the handle deterministically instead of leaking it.
        with file_path.open('wb') as handle:
            torch.save(states, handle)
        if not silent:
            print("=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter))

    def load_checkpoint(self, filename):
        """Restore state from ``ckpt_dir/filename`` if that file exists.

        Prints a message either way; a missing file is not an error.
        """
        file_path = self.ckpt_dir.joinpath(filename)
        if not file_path.is_file():
            print("=> no checkpoint found at '{}'".format(file_path))
            return

        # NOTE(review): torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        # Map to CPU when CUDA is not in use so GPU-saved files still load.
        map_location = None if self.use_cuda else 'cpu'
        with file_path.open('rb') as handle:
            checkpoint = torch.load(handle, map_location=map_location)

        self.global_iter = checkpoint['iter']
        win_states = checkpoint['win_states']
        self.win_recon = win_states['recon']
        self.win_kld = win_states['kld']
        self.win_var = win_states['var']
        self.win_mu = win_states['mu']
        self.net.load_state_dict(checkpoint['model_states']['net'])
        self.optim.load_state_dict(checkpoint['optim_states']['optim'])
        print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))