Skip to content
Snippets Groups Projects
Commit 1117b138 authored by Tony Metger's avatar Tony Metger
Browse files

Add extra options for visualization and fix bug for viz_on False

parent 5d47d435
No related branches found
No related tags found
No related merge requests found
......@@ -57,6 +57,10 @@ if __name__ == "__main__":
parser.add_argument('--save_output', default=True, type=str2bool, help='save traverse images and gif')
parser.add_argument('--output_dir', default='outputs', type=str, help='output directory')
parser.add_argument('--gather_step', default=1000, type=int, help='numer of iterations after which data is gathered for visdom')
parser.add_argument('--display_step', default=10000, type=int, help='number of iterations after which loss data is printed and visdom is updated')
parser.add_argument('--save_step', default=10000, type=int, help='number of iterations after which a checkpoint is saved')
parser.add_argument('--ckpt_dir', default='checkpoints', type=str, help='checkpoint directory')
parser.add_argument('--ckpt_name', default='last', type=str, help='load previous checkpoint. insert checkpoint filename')
......
......@@ -130,6 +130,10 @@ class Solver(object):
if not self.output_dir.exists():
self.output_dir.mkdir(parents=True, exist_ok=True)
self.gather_step = args.gather_step
self.display_step = args.display_step
self.save_step = args.save_step
self.dset_dir = args.dset_dir
self.dataset = args.dataset
self.batch_size = args.batch_size
......@@ -164,19 +168,20 @@ class Solver(object):
beta_vae_loss.backward()
self.optim.step()
if self.global_iter%1000 == 0:
if self.viz_on and self.global_iter%self.gather_step == 0:
self.gather.insert(iter=self.global_iter,
mu=mu.mean(0).data, var=logvar.exp().mean(0).data,
recon_loss=recon_loss.data, total_kld=total_kld.data,
dim_wise_kld=dim_wise_kld.data, mean_kld=mean_kld.data)
if self.global_iter%10000 == 0:
self.gather.insert(images=x.data)
self.gather.insert(images=F.sigmoid(x_recon).data)
self.viz_reconstruction()
self.viz_lines()
if self.global_iter%self.display_step == 0:
if self.viz_on:
self.gather.insert(images=x.data)
self.gather.insert(images=F.sigmoid(x_recon).data)
self.viz_reconstruction()
self.viz_lines()
self.gather.flush()
self.save_checkpoint('last')
pbar.write('[{}] recon_loss:{:.3f} total_kld:{:.3f} mean_kld:{:.3f}'.format(
self.global_iter, recon_loss.data[0], total_kld.data[0], mean_kld.data[0]))
......@@ -188,6 +193,10 @@ class Solver(object):
if self.objective == 'advanced':
pbar.write('C:{:.3f}'.format(C.data[0]))
if self.global_iter%self.save_step == 0:
self.save_checkpoint('last')
pbar.write('Saved checkpoint')
if self.global_iter%20000 == 0:
self.viz_traverse()
......@@ -421,10 +430,16 @@ class Solver(object):
def save_checkpoint(self, filename, silent=True):
model_states = {'net':self.net.state_dict(),}
optim_states = {'optim':self.optim.state_dict(),}
win_states = {'recon':self.win_recon,
'kld':self.win_kld,
'mu':self.win_mu,
'var':self.win_var,}
if self.viz_on:
win_states = {'recon':self.win_recon,
'kld':self.win_kld,
'mu':self.win_mu,
'var':self.win_var,}
else:
win_states = {'recon':None,
'kld':None,
'mu':None,
'var':None,}
states = {'iter':self.global_iter,
'win_states':win_states,
'model_states':model_states,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment