"""Beta-VAE training utilities: reconstruction/KL losses, metric gathering, and a Solver class."""
import random

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
def reconstruction_loss(x, x_recon, distribution):
    """Per-batch reconstruction loss for a Bernoulli or Gaussian decoder.

    Args:
        x: target images, shape (batch, ...).
        x_recon: decoder output *logits* of the same shape as ``x``.
        distribution: 'bernoulli' (BCE-with-logits) or 'gaussian'
            (sigmoid, then MSE). Any other value yields ``None``.

    Returns:
        Scalar tensor: loss summed over all elements and divided by the
        batch size, or ``None`` for an unrecognized distribution.
    """
    batch_size = x.size(0)
    assert batch_size != 0

    if distribution == 'bernoulli':
        # reduction='sum' replaces the deprecated size_average=False.
        recon_loss = F.binary_cross_entropy_with_logits(
            x_recon, x, reduction='sum').div(batch_size)
    elif distribution == 'gaussian':
        # torch.sigmoid replaces the deprecated F.sigmoid.
        x_recon = torch.sigmoid(x_recon)
        recon_loss = F.mse_loss(x_recon, x, reduction='sum').div(batch_size)
    else:
        recon_loss = None

    return recon_loss
def kl_divergence(mu, logvar):
    """KL(q(z|x) || N(0, I)) for a diagonal-Gaussian posterior.

    Args:
        mu: (batch, z_dim) posterior means; a 4-d (batch, z_dim, 1, 1)
            tensor is flattened to 2-d first.
        logvar: (batch, z_dim) posterior log-variances; same 4-d handling.

    Returns:
        total_kld: shape (1,) — per-sample KLD summed over latent dims,
            averaged over the batch.
        dimension_wise_kld: shape (z_dim,) — per-dimension KLD averaged
            over the batch.
        mean_kld: shape (1,) — KLD averaged over dims and batch.
    """
    batch_size = mu.size(0)
    assert batch_size != 0

    # Conv encoders may emit (B, z, 1, 1); flatten to (B, z).
    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    # Closed-form KL between N(mu, exp(logvar)) and the standard normal.
    klds = -0.5*(1 + logvar - mu.pow(2) - logvar.exp())

    # These aggregates were missing before the return, raising NameError.
    total_kld = klds.sum(1).mean(0, True)
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean(1).mean(0, True)

    return total_kld, dimension_wise_kld, mean_kld
class DataGather(object):
    """Accumulates per-iteration training statistics for later visualization."""

    # Every tracked quantity; each maps to a list that grows one entry per insert.
    _KEYS = ('iter', 'recon_loss', 'total_kld', 'dim_wise_kld',
             'mean_kld', 'mu', 'var', 'images')

    def __init__(self):
        self.data = self.get_empty_data_dict()

    def get_empty_data_dict(self):
        """Return a fresh dict mapping every tracked key to an empty list."""
        return {key: [] for key in self._KEYS}

    def insert(self, **kwargs):
        """Append each keyword value to the list stored under its key."""
        for key, value in kwargs.items():
            self.data[key].append(value)
class Solver(object):
    """Trains a beta-VAE ('H' or 'B' objective) and logs curves/images to visdom."""

    def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_iter = args.max_iter
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.beta = args.beta
        self.gamma = args.gamma
        self.C_max = args.C_max
        self.C_stop_iter = args.C_stop_iter
        self.objective = args.objective
        self.model = args.model
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        # self.dataset / self.nc / self.decoder_dist are read later by
        # train() and viz_traverse() but were never assigned here.
        # NOTE(review): dataset -> (nc, decoder_dist) mapping assumed from
        # the dataset names used elsewhere in this file — confirm.
        self.dataset = args.dataset.lower()
        if self.dataset == 'dsprites':
            self.nc = 1
            self.decoder_dist = 'bernoulli'
        elif self.dataset in ('3dchairs', 'celeba'):
            self.nc = 3
            self.decoder_dist = 'gaussian'
        else:
            raise NotImplementedError('unsupported dataset: {}'.format(args.dataset))

        if args.model == 'H':
            net = BetaVAE_H
        elif args.model == 'B':
            net = BetaVAE_B
        else:
            raise NotImplementedError('only support model H or B')

        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        # Visdom window handles; created lazily by viz_lines() and
        # persisted/restored by save_checkpoint()/load_checkpoint().
        self.win_recon = None
        self.win_kld = None
        self.win_mu = None
        self.win_var = None
        if self.viz_on:
            self.viz = visdom.Visdom(env=self.viz_name+'_lines', port=self.viz_port)

        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        self.gather_step = args.gather_step
        self.display_step = args.display_step
        self.save_step = args.save_step

        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        # Accumulator for the stats/images consumed by the viz_* methods.
        self.gather = DataGather()
def train(self):
    """Run the optimization loop until ``self.global_iter`` reaches ``self.max_iter``."""
    self.net_mode(train=True)
    self.C_max = Variable(cuda(torch.FloatTensor([self.C_max]), self.use_cuda))
    out = False  # was never initialized -> NameError on `while not out`

    pbar = tqdm(total=self.max_iter)
    pbar.update(self.global_iter)

    while not out:
        for x in self.data_loader:
            self.global_iter += 1
            pbar.update(1)  # advance the bar once per iteration

            x = Variable(cuda(x, self.use_cuda))
            x_recon, mu, logvar = self.net(x)
            recon_loss = reconstruction_loss(x, x_recon, self.decoder_dist)
            total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)

            if self.objective == 'H':
                # Higgins et al.: recon + beta * KL.
                beta_vae_loss = recon_loss + self.beta*total_kld
            elif self.objective == 'B':
                # Burgess et al.: capacity C ramps linearly to C_max over C_stop_iter.
                C = torch.clamp(self.C_max/self.C_stop_iter*self.global_iter, 0, self.C_max.data[0])
                beta_vae_loss = recon_loss + self.gamma*(total_kld-C).abs()

            self.optim.zero_grad()
            beta_vae_loss.backward()
            self.optim.step()

            if self.viz_on and self.global_iter%self.gather_step == 0:
                self.gather.insert(iter=self.global_iter,
                                   mu=mu.mean(0).data, var=logvar.exp().mean(0).data,
                                   recon_loss=recon_loss.data, total_kld=total_kld.data,
                                   dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data)

            if self.global_iter%self.display_step == 0:
                if self.viz_on:
                    self.gather.insert(images=x.data)
                    self.gather.insert(images=F.sigmoid(x_recon).data)
                    self.viz_reconstruction()
                    self.viz_lines()

                pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
                    self.global_iter, recon_loss.data[0], total_kld.data[0], mean_kld.data[0]))

                var = logvar.exp().mean(0).data
                var_str = ''
                for j, var_j in enumerate(var):
                    var_str += 'var{}:{:.4f} '.format(j+1, var_j)
                pbar.write(var_str)

                # C exists only under the 'B' objective; the old
                # `objective == 'advanced'` guard was dead code.
                if self.objective == 'B':
                    pbar.write('C:{:.3f}'.format(C.data[0]))

            if self.global_iter%self.save_step == 0:
                self.save_checkpoint('last')
                pbar.write('Saved checkpoint')

            if self.global_iter%50000 == 0:
                self.save_checkpoint(str(self.global_iter))

            if self.global_iter >= self.max_iter:
                out = True
                break

    pbar.close()
def viz_reconstruction(self):
    """Push the most recently gathered input/reconstruction grids to visdom."""
    self.net_mode(train=False)
    # gather.data['images'] holds [inputs, sigmoid(reconstructions)] (see train()).
    x = self.gather.data['images'][0][:100]
    x = make_grid(x, normalize=True)
    x_recon = self.gather.data['images'][1][:100]
    x_recon = make_grid(x_recon, normalize=True)
    images = torch.stack([x, x_recon], dim=0).cpu()
    self.viz.images(images, env=self.viz_name+'_reconstruction',
                    opts=dict(title=str(self.global_iter)), nrow=10)
    # Restore training mode; net_mode(train=False) above switched to eval.
    self.net_mode(train=True)
def viz_lines(self):
    """Plot recon-loss, KLD, posterior-mean and posterior-variance curves in visdom."""
    self.net_mode(train=False)

    recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
    mus = torch.stack(self.gather.data['mu']).cpu()
    vars = torch.stack(self.gather.data['var']).cpu()

    dim_wise_klds = torch.stack(self.gather.data['dim_wise_kld'])
    mean_klds = torch.stack(self.gather.data['mean_kld'])
    total_klds = torch.stack(self.gather.data['total_kld'])
    klds = torch.cat([dim_wise_klds, mean_klds, total_klds], 1).cpu()

    # Shared X axis for every window; was referenced but never defined.
    iters = torch.Tensor(self.gather.data['iter'])

    legend = []
    for z_j in range(self.z_dim):
        legend.append('z_{}'.format(z_j))
    legend.append('mean')
    legend.append('total')

    # Each window is created on first call, then appended to thereafter.
    if self.win_recon is None:
        self.win_recon = self.viz.line(
            X=iters,
            Y=recon_losses,
            env=self.viz_name+'_lines',
            opts=dict(
                width=400,
                height=400,
                xlabel='iteration',
                title='reconstruction loss',))
    else:
        self.win_recon = self.viz.line(
            X=iters,
            Y=recon_losses,
            env=self.viz_name+'_lines',
            win=self.win_recon,
            update='append',
            opts=dict(
                width=400,
                height=400,
                xlabel='iteration',
                title='reconstruction loss',))

    if self.win_kld is None:
        self.win_kld = self.viz.line(
            X=iters,
            Y=klds,
            env=self.viz_name+'_lines',
            opts=dict(
                width=400,
                height=400,
                legend=legend,
                xlabel='iteration',
                title='kl divergence',))
    else:
        self.win_kld = self.viz.line(
            X=iters,
            Y=klds,
            env=self.viz_name+'_lines',
            win=self.win_kld,
            update='append',
            opts=dict(
                width=400,
                height=400,
                legend=legend,
                xlabel='iteration',
                title='kl divergence',))

    if self.win_mu is None:
        self.win_mu = self.viz.line(
            X=iters,
            Y=mus,
            env=self.viz_name+'_lines',
            opts=dict(
                width=400,
                height=400,
                legend=legend[:self.z_dim],
                xlabel='iteration',
                title='posterior mean',))
    else:
        # Bug fix: the append branch plotted Y=vars into the mean window.
        self.win_mu = self.viz.line(
            X=iters,
            Y=mus,
            env=self.viz_name+'_lines',
            win=self.win_mu,
            update='append',
            opts=dict(
                width=400,
                height=400,
                legend=legend[:self.z_dim],
                xlabel='iteration',
                title='posterior mean',))

    if self.win_var is None:
        self.win_var = self.viz.line(
            X=iters,
            Y=vars,
            env=self.viz_name+'_lines',
            opts=dict(
                width=400,
                height=400,
                legend=legend[:self.z_dim],
                xlabel='iteration',
                title='posterior variance',))
    else:
        self.win_var = self.viz.line(
            X=iters,
            Y=vars,
            env=self.viz_name+'_lines',
            win=self.win_var,
            update='append',
            opts=dict(
                width=400,
                height=400,
                legend=legend[:self.z_dim],
                xlabel='iteration',
                title='posterior variance',))

    self.net_mode(train=True)
def viz_traverse(self, limit=3, inter=2/3, loc=-1):
    """Render latent traversals: vary one z dimension at a time over [-limit, limit].

    Args:
        limit: traversal endpoint magnitude for each latent dimension.
        inter: step size between interpolation points.
        loc: if >= 0, traverse only that latent dimension.
    """
    self.net_mode(train=False)
    decoder = self.net.decoder
    encoder = self.net.encoder
    interpolation = torch.arange(-limit, limit+0.1, inter)

    # One random dataset image and one random latent vector as anchors.
    # NOTE(review): random_img/random_img_z were used below but never
    # defined in this chunk — reconstructed as a random dataset sample.
    n_dsets = len(self.data_loader.dataset)
    rand_idx = random.randint(1, n_dsets-1)
    random_img = self.data_loader.dataset.__getitem__(rand_idx)
    random_img = Variable(cuda(random_img, self.use_cuda), volatile=True).unsqueeze(0)
    random_img_z = encoder(random_img)[:, :self.z_dim]

    random_z = Variable(cuda(torch.rand(1, self.z_dim), self.use_cuda), volatile=True)

    if self.dataset == 'dsprites':
        # Fixed exemplar of each dsprites shape class.
        fixed_idx1 = 87040   # square
        fixed_idx2 = 332800  # ellipse
        fixed_idx3 = 578560  # heart

        fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)
        fixed_img1 = Variable(cuda(fixed_img1, self.use_cuda), volatile=True).unsqueeze(0)
        fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

        fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)
        fixed_img2 = Variable(cuda(fixed_img2, self.use_cuda), volatile=True).unsqueeze(0)
        fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

        fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)
        fixed_img3 = Variable(cuda(fixed_img3, self.use_cuda), volatile=True).unsqueeze(0)
        fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

        Z = {'fixed_square':fixed_img_z1, 'fixed_ellipse':fixed_img_z2,
             'fixed_heart':fixed_img_z3, 'random_img':random_img_z}
    else:
        fixed_idx = 0
        fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
        fixed_img = Variable(cuda(fixed_img, self.use_cuda), volatile=True).unsqueeze(0)
        fixed_img_z = encoder(fixed_img)[:, :self.z_dim]

        Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z}

    gifs = []
    for key in Z.keys():
        z_ori = Z[key]
        samples = []
        for row in range(self.z_dim):
            if loc != -1 and row != loc:
                continue  # traverse only the requested dimension
            z = z_ori.clone()
            for val in interpolation:
                z[:, row] = val
                sample = F.sigmoid(decoder(z)).data
                samples.append(sample)
                gifs.append(sample)
        samples = torch.cat(samples, dim=0).cpu()
        title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
        self.viz.images(samples, env=self.viz_name+'_traverse',
                        opts=dict(title=title), nrow=len(interpolation))

    if self.save_output:
        output_dir = self.output_dir.joinpath(str(self.global_iter))
        output_dir.mkdir(parents=True, exist_ok=True)
        gifs = torch.cat(gifs)
        # (anchors, z_dim, steps, C, H, W) -> swap so each step is a frame.
        gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64, 64).transpose(1, 2)
        for i, key in enumerate(Z.keys()):
            for j, val in enumerate(interpolation):
                save_image(tensor=gifs[i][j].cpu(),
                           filename=output_dir.joinpath('{}_{}.jpg'.format(key, j)),
                           nrow=self.z_dim, pad_value=1)
            grid2gif(str(output_dir.joinpath(key+'*.jpg')),
                     str(output_dir.joinpath(key+'.gif')), delay=10)

    self.net_mode(train=True)
def net_mode(self, train):
    """Switch the wrapped network between train and eval mode.

    Args:
        train: True for training mode, False for evaluation mode.

    Raises:
        TypeError: if *train* is not a bool.
    """
    if not isinstance(train, bool):
        # `raise '...'` raises "exceptions must derive from BaseException";
        # raise a real exception type instead.
        raise TypeError('Only bool type is supported. True or False')

    if train:
        self.net.train()
    else:
        self.net.eval()
def save_checkpoint(self, filename, silent=True):
    """Serialize model/optimizer state plus bookkeeping to ckpt_dir/filename.

    Args:
        filename: checkpoint file name inside ``self.ckpt_dir``.
        silent: when False, print a confirmation line.
    """
    model_states = {'net':self.net.state_dict(),}
    optim_states = {'optim':self.optim.state_dict(),}
    # Visdom window ids are persisted so curves keep appending after resume.
    if self.viz_on:
        win_states = {'recon':self.win_recon,
                      'kld':self.win_kld,
                      'mu':self.win_mu,
                      'var':self.win_var,}
    else:
        win_states = {'recon':None,
                      'kld':None,
                      'mu':None,
                      'var':None,}
    states = {'iter':self.global_iter,
              'win_states':win_states,
              'model_states':model_states,
              'optim_states':optim_states}

    file_path = self.ckpt_dir.joinpath(filename)
    # Context manager closes the handle even if torch.save raises;
    # the old bare file_path.open('wb+') leaked it.
    with file_path.open('wb+') as f:
        torch.save(states, f)
    if not silent:
        print("=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter))
def load_checkpoint(self, filename):
    """Restore iteration count, visdom windows, and net/optim state from a checkpoint.

    Args:
        filename: checkpoint file name inside ``self.ckpt_dir``; a missing
            file is reported but not treated as an error.
    """
    file_path = self.ckpt_dir.joinpath(filename)
    if file_path.is_file():
        # Context manager closes the handle the old bare open('rb') leaked.
        with file_path.open('rb') as f:
            checkpoint = torch.load(f)
        self.global_iter = checkpoint['iter']
        self.win_recon = checkpoint['win_states']['recon']
        self.win_kld = checkpoint['win_states']['kld']
        self.win_var = checkpoint['win_states']['var']
        self.win_mu = checkpoint['win_states']['mu']
        self.net.load_state_dict(checkpoint['model_states']['net'])
        self.optim.load_state_dict(checkpoint['optim_states']['optim'])
        print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))
    else:
        print("=> no checkpoint found at '{}'".format(file_path))