import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

# Additional imports required by the code below. The project-local module paths
# (utils, model, dataset) are assumed from the surrounding project layout.
import random
from pathlib import Path

import visdom
from tqdm import tqdm
from torchvision.utils import make_grid, save_image

from utils import cuda, grid2gif
from model import BetaVAE_H, BetaVAE_B
from dataset import return_data

def reconstruction_loss(x, x_recon, distribution):
    batch_size = x.size(0)
    assert batch_size != 0

    if distribution == 'bernoulli':
        recon_loss = F.binary_cross_entropy_with_logits(x_recon, x, size_average=False).div(batch_size)
    elif distribution == 'gaussian':
        x_recon = F.sigmoid(x_recon)
        recon_loss = F.mse_loss(x_recon, x, size_average=False).div(batch_size)
    else:
        recon_loss = None

    return recon_loss

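# Note: a 'bernoulli' decoder is scored on its raw (pre-sigmoid) logits with binary
# cross-entropy, while a 'gaussian' decoder is first squashed with a sigmoid and
# scored with mean squared error; in both cases the loss is summed over pixels and
# divided only by the batch size.
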
def kl_divergence(mu, logvar):
    batch_size = mu.size(0)
    assert batch_size != 0

    if mu.data.ndimension() == 4:
        mu = mu.view(mu.size(0), mu.size(1))
    if logvar.data.ndimension() == 4:
        logvar = logvar.view(logvar.size(0), logvar.size(1))

    klds = -0.5*(1 + logvar - mu.pow(2) - logvar.exp())
    total_kld = klds.sum(1).mean()
    dimension_wise_kld = klds.mean(0)
    mean_kld = klds.mean()

    return total_kld, dimension_wise_kld, mean_kld

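# The expression above is the standard closed form of the KL divergence between a
# diagonal Gaussian posterior and the unit Gaussian prior, applied element-wise:
#   KL( N(mu, sigma^2) || N(0, 1) ) = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
# with logvar = log(sigma^2). total_kld sums it over latent dimensions and averages
# over the batch, dimension_wise_kld averages over the batch per dimension, and
# mean_kld averages over both.
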
class DataGather(object):
    def __init__(self):
        self.data = self.get_empty_data_dict()

    def get_empty_data_dict(self):
        return dict(iter=[],
                    recon_loss=[],
                    total_kld=[],
                    dim_wise_kld=[],
                    mean_kld=[],
                    mu=[],
                    var=[],
                    images=[],)

    def insert(self, **kwargs):
        for key in kwargs:
            self.data[key].append(kwargs[key])

    def flush(self):
        # flush() is called after each logging/visualization pass below; resetting
        # the dictionary is the natural behaviour and is assumed here.
        self.data = self.get_empty_data_dict()

class Solver(object):
    def __init__(self, args):
        self.use_cuda = args.cuda and torch.cuda.is_available()
        self.max_iter = args.max_iter
        self.global_iter = 0

        self.z_dim = args.z_dim
        self.beta = args.beta
        self.gamma = args.gamma
        self.C_max = args.C_max
        self.C_stop_iter = args.C_stop_iter
        self.objective = args.objective
        self.model = args.model
        self.lr = args.lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        # The attributes below are referenced later in this class but their
        # initialization is not shown in this excerpt; dataset-dependent values
        # are assumed here (dsprites is single-channel binary data).
        self.dataset = args.dataset
        self.nc = 1 if args.dataset == 'dsprites' else 3
        self.decoder_dist = 'bernoulli' if args.dataset == 'dsprites' else 'gaussian'

        if args.model == 'H':
            net = BetaVAE_H
        elif args.model == 'B':
            net = BetaVAE_B
        else:
            raise NotImplementedError('only support model H or B')

        self.net = cuda(net(self.z_dim, self.nc), self.use_cuda)
        self.optim = optim.Adam(self.net.parameters(), lr=self.lr,
                                betas=(self.beta1, self.beta2))

        self.viz_name = args.viz_name
        self.viz_port = args.viz_port
        self.viz_on = args.viz_on
        # Visdom window handles start empty and are created or appended to by viz_lines().
        self.win_recon = None
        self.win_kld = None
        self.win_mu = None
        self.win_var = None
        if self.viz_on:
            self.viz = visdom.Visdom(env=self.viz_name+'_lines', port=self.viz_port)

        self.ckpt_dir = Path(args.ckpt_dir).joinpath(args.viz_name)
        if not self.ckpt_dir.exists():
            self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.ckpt_name = args.ckpt_name
        if self.ckpt_name is not None:
            self.load_checkpoint(self.ckpt_name)

        self.save_output = args.save_output
        self.output_dir = Path(args.output_dir).joinpath(args.viz_name)
        if not self.output_dir.exists():
            self.output_dir.mkdir(parents=True, exist_ok=True)

        self.batch_size = args.batch_size
        self.data_loader = return_data(args)

        self.gather = DataGather()

    def train(self):
        self.net_mode(train=True)
        self.C_max = Variable(cuda(torch.FloatTensor([self.C_max]), self.use_cuda))
        out = False

        pbar = tqdm(total=self.max_iter)
        pbar.update(self.global_iter)

        while not out:
            for x in self.data_loader:
                self.global_iter += 1

                x = Variable(cuda(x, self.use_cuda))
                x_recon, mu, logvar = self.net(x)
                recon_loss = reconstruction_loss(x, x_recon, self.decoder_dist)
                total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
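                # Two objectives are used below: 'H' is the plain beta-VAE loss,
                # recon_loss + beta * KL, while 'B' is the controlled-capacity
                # variant, recon_loss + gamma * |KL - C|, where the target
                # capacity C is increased linearly from 0 to C_max over the first
                # C_stop_iter iterations so the KL term is allowed to grow slowly.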
                if self.objective == 'H':
                    beta_vae_loss = recon_loss + self.beta*total_kld
                elif self.objective == 'B':
                    C = torch.clamp(self.C_max/self.C_stop_iter*self.global_iter, 0, self.C_max.data[0])
                    beta_vae_loss = recon_loss + self.gamma*(total_kld-C).abs()

                self.optim.zero_grad()
                beta_vae_loss.backward()
                self.optim.step()

                if self.global_iter%1000 == 0:
                    self.gather.insert(iter=self.global_iter,
                                       mu=mu.mean(0).data, var=logvar.exp().mean(0).data,
                                       recon_loss=recon_loss.data, total_kld=total_kld.data,
                                       dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data)
                    # The inputs are gathered alongside their reconstructions so that
                    # viz_reconstruction() can show both; these visualization calls and
                    # their exact placement are assumed, but the data must be gathered
                    # before flush() clears it.
                    self.gather.insert(images=x.data)
                    self.gather.insert(images=F.sigmoid(x_recon).data)
                    if self.viz_on:
                        self.viz_reconstruction()
                        self.viz_lines()
                    self.gather.flush()
                    self.save_checkpoint('last')

                    pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
                        self.global_iter, recon_loss.data[0], total_kld.data[0], mean_kld.data[0]))

                    var = logvar.exp().mean(0).data
                    var_str = ''
                    for j, var_j in enumerate(var):
                        var_str += 'var{}:{:.4f} '.format(j+1, var_j)
                    pbar.write(var_str)

                    if self.objective == 'B':
                        pbar.write('C:{:.3f}'.format(C.data[0]))

                if self.global_iter%50000 == 0:
                    self.save_checkpoint(str(self.global_iter))

                if self.global_iter >= self.max_iter:
                    out = True
                    break

    def viz_reconstruction(self):
        # Displays the most recently gathered input batch and its reconstruction.
        self.net_mode(train=False)
        x = self.gather.data['images'][0][:100]
        x = make_grid(x, normalize=True)
        x_recon = self.gather.data['images'][1][:100]
        x_recon = make_grid(x_recon, normalize=True)
        images = torch.stack([x, x_recon], dim=0).cpu()
        self.viz.images(images, env=self.viz_name+'_reconstruction',
                        opts=dict(title=str(self.global_iter)), nrow=10)

    def viz_lines(self):
        # Pushes the gathered training curves to visdom line plots.
        self.net_mode(train=False)

        recon_losses = torch.stack(self.gather.data['recon_loss']).cpu()
        mus = torch.stack(self.gather.data['mu']).cpu()
        vars = torch.stack(self.gather.data['var']).cpu()

        dim_wise_klds = torch.stack(self.gather.data['dim_wise_kld'])
        mean_klds = torch.stack(self.gather.data['mean_kld'])
        total_klds = torch.stack(self.gather.data['total_kld'])
        klds = torch.cat([dim_wise_klds, mean_klds, total_klds], 1).cpu()

        # x-axis values shared by every plot (one entry per logging step).
        iters = torch.Tensor(self.gather.data['iter'])

        legend = []
        for z_j in range(self.z_dim):
            legend.append('z_{}'.format(z_j))
        legend.append('mean')
        legend.append('total')

        # The two stanzas below follow the same create-or-append pattern used for
        # the posterior mean/variance windows further down; plot titles are assumed.
        if self.win_recon is None:
            self.win_recon = self.viz.line(
                X=iters, Y=recon_losses, env=self.viz_name+'_lines',
                opts=dict(width=400, height=400, xlabel='iteration',
                          title='reconstruction loss',))
        else:
            self.win_recon = self.viz.line(
                X=iters, Y=recon_losses, env=self.viz_name+'_lines',
                win=self.win_recon, update='append',
                opts=dict(width=400, height=400, xlabel='iteration',
                          title='reconstruction loss',))

        if self.win_kld is None:
            self.win_kld = self.viz.line(
                X=iters, Y=klds, env=self.viz_name+'_lines',
                opts=dict(width=400, height=400, legend=legend,
                          xlabel='iteration', title='kl divergence',))
        else:
            self.win_kld = self.viz.line(
                X=iters, Y=klds, env=self.viz_name+'_lines',
                win=self.win_kld, update='append',
                opts=dict(width=400, height=400, legend=legend,
                          xlabel='iteration', title='kl divergence',))

        if self.win_mu is None:
            self.win_mu = self.viz.line(
                X=iters,
                Y=mus,
                env=self.viz_name+'_lines',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='posterior mean',))
        else:
            self.win_mu = self.viz.line(
                X=iters,
                Y=mus,
                env=self.viz_name+'_lines',
                win=self.win_mu,
                update='append',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='posterior mean',))
        if self.win_var is None:
            self.win_var = self.viz.line(
                X=iters,
                Y=vars,
                env=self.viz_name+'_lines',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='posterior variance',))
        else:
            self.win_var = self.viz.line(
                X=iters,
                Y=vars,
                env=self.viz_name+'_lines',
                win=self.win_var,
                update='append',
                opts=dict(
                    width=400,
                    height=400,
                    legend=legend[:self.z_dim],
                    xlabel='iteration',
                    title='posterior variance',))

        self.net_mode(train=True)

    def viz_traverse(self, limit=3, inter=2/3):
        # Latent-traversal visualization; the method name and the default sweep
        # range/step are assumptions not shown in this excerpt.
        decoder = self.net.decoder
        encoder = self.net.encoder
        interpolation = torch.arange(-limit, limit+0.1, inter)

        # A randomly chosen dataset image is encoded in the same way as the fixed
        # reference images below (the index selection is assumed).
        rand_idx = random.randint(1, len(self.data_loader.dataset)-1)
        random_img = self.data_loader.dataset.__getitem__(rand_idx)
        random_img = Variable(cuda(random_img, self.use_cuda), volatile=True).unsqueeze(0)
        random_img_z = encoder(random_img)[:, :self.z_dim]
        random_z = Variable(cuda(torch.rand(1, self.z_dim), self.use_cuda), volatile=True)

        if self.dataset == 'dsprites':
            fixed_idx1 = 87040   # square
            fixed_idx2 = 332800  # ellipse
            fixed_idx3 = 578560  # heart

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)
            fixed_img1 = Variable(cuda(fixed_img1, self.use_cuda), volatile=True).unsqueeze(0)
            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)
            fixed_img2 = Variable(cuda(fixed_img2, self.use_cuda), volatile=True).unsqueeze(0)
            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)
            fixed_img3 = Variable(cuda(fixed_img3, self.use_cuda), volatile=True).unsqueeze(0)
            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]

            Z = {'fixed_square':fixed_img_z1, 'fixed_ellipse':fixed_img_z2,
                 'fixed_heart':fixed_img_z3, 'random_img':random_img_z}
        else:
            fixed_idx = 0
            fixed_img = self.data_loader.dataset.__getitem__(fixed_idx)
            fixed_img = Variable(cuda(fixed_img, self.use_cuda), volatile=True).unsqueeze(0)
            fixed_img_z = encoder(fixed_img)[:, :self.z_dim]

            Z = {'fixed_img':fixed_img_z, 'random_img':random_img_z, 'random_z':random_z}
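        # Traversal: for every reference code in Z, each latent dimension is swept in
        # turn across the `interpolation` values while the other dimensions are held
        # fixed, and every perturbed code is decoded back into an image.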
        gifs = []
        for key in Z:
            z_ori = Z[key]
            samples = []
            for row in range(self.z_dim):
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = F.sigmoid(decoder(z)).data  # decode and squash to image range (assumed)
                    samples.append(sample)
                    gifs.append(sample)
            samples = torch.cat(samples, dim=0).cpu()
            title = '{}_latent_traversal(iter:{})'.format(key, self.global_iter)
            self.viz.images(samples, env=self.viz_name+'_traverse',
                            opts=dict(title=title), nrow=len(interpolation))
        if self.save_output:
            output_dir = self.output_dir.joinpath(str(self.global_iter))
            output_dir.mkdir(parents=True, exist_ok=True)
            gifs = torch.cat(gifs)
            gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64, 64).transpose(1, 2)
            for i, key in enumerate(Z.keys()):
                for j, val in enumerate(interpolation):
                    save_image(tensor=gifs[i][j].cpu(),
                               filename=output_dir.joinpath('{}_{}.jpg'.format(key, j)),
                               nrow=self.z_dim, pad_value=1)
                grid2gif(str(output_dir.joinpath(key+'*.jpg')),
                         str(output_dir.joinpath(key+'.gif')), delay=10)

    def net_mode(self, train):
        if not isinstance(train, bool):
            raise TypeError('Only bool type is supported. True or False')

        if train:
            self.net.train()
        else:
            self.net.eval()

    def save_checkpoint(self, filename, silent=True):
        model_states = {'net':self.net.state_dict(),}
        optim_states = {'optim':self.optim.state_dict(),}
        win_states = {'recon':self.win_recon,
                      'kld':self.win_kld,
                      'mu':self.win_mu,
                      'var':self.win_var,}
        states = {'iter':self.global_iter,
                  'win_states':win_states,
                  'model_states':model_states,
                  'optim_states':optim_states}

        file_path = self.ckpt_dir.joinpath(filename)
        torch.save(states, file_path.open('wb+'))
        if not silent:
            print("=> saved checkpoint '{}' (iter {})".format(file_path, self.global_iter))

    def load_checkpoint(self, filename):
        file_path = self.ckpt_dir.joinpath(filename)
        if file_path.is_file():
            checkpoint = torch.load(file_path.open('rb'))
            self.global_iter = checkpoint['iter']
            self.win_recon = checkpoint['win_states']['recon']
            self.win_kld = checkpoint['win_states']['kld']
            self.win_var = checkpoint['win_states']['var']
            self.win_mu = checkpoint['win_states']['mu']
            self.net.load_state_dict(checkpoint['model_states']['net'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim'])
            print("=> loaded checkpoint '{} (iter {})'".format(file_path, self.global_iter))
        else:
            print("=> no checkpoint found at '{}'".format(file_path))