diff --git a/config.py b/config.py index c991e4e..3eb731e 100644 --- a/config.py +++ b/config.py @@ -3,53 +3,90 @@ import argparse import time -parser = argparse.ArgumentParser('PGGAN') +parser = argparse.ArgumentParser("PGGAN") ## general settings. -parser.add_argument('--train_data_root', type=str, default='/homes/user/Desktop/YOUR_DIRECTORY') -parser.add_argument('--random_seed', type=int, default=int(time.time())) -parser.add_argument('--n_gpu', type=int, default=1) # for Multi-GPU training. +parser.add_argument( + "--train_data_root", type=str, default="/home/veesion/nabirds/images/" +) +parser.add_argument("--random_seed", type=int, default=int(time.time())) +parser.add_argument("--n_gpu", type=int, default=1) # for Multi-GPU training. ## training parameters. -parser.add_argument('--lr', type=float, default=0.001) # learning rate. -parser.add_argument('--lr_decay', type=float, default=0.87) # learning rate decay at every resolution transition. -parser.add_argument('--eps_drift', type=float, default=0.001) # coeff for the drift loss. -parser.add_argument('--smoothing', type=float, default=0.997) # smoothing factor for smoothed generator. -parser.add_argument('--nc', type=int, default=3) # number of input channel. -parser.add_argument('--nz', type=int, default=512) # input dimension of noise. -parser.add_argument('--ngf', type=int, default=512) # feature dimension of final layer of generator. -parser.add_argument('--ndf', type=int, default=512) # feature dimension of first layer of discriminator. -parser.add_argument('--TICK', type=int, default=1000) # 1 tick = 1000 images = (1000/batch_size) iter. -parser.add_argument('--max_resl', type=int, default=8) # 10-->1024, 9-->512, 8-->256 -parser.add_argument('--trns_tick', type=int, default=200) # transition tick -parser.add_argument('--stab_tick', type=int, default=100) # stabilization tick +parser.add_argument("--lr", type=float, default=0.001) # learning rate. +parser.add_argument( + "--lr_decay", type=float, default=0.9 +) # learning rate decay at every resolution transition. +parser.add_argument( + "--eps_drift", type=float, default=0.001 +) # coeff for the drift loss. +parser.add_argument( + "--smoothing", type=float, default=0.997 +) # smoothing factor for smoothed generator. +parser.add_argument("--nc", type=int, default=3) # number of input channel. +parser.add_argument("--nz", type=int, default=512) # input dimension of noise. +parser.add_argument( + "--ngf", type=int, default=512 +) # feature dimension of final layer of generator. +parser.add_argument( + "--ndf", type=int, default=512 +) # feature dimension of first layer of discriminator. +parser.add_argument( + "--TICK", type=int, default=1000 +) # 1 tick = 1000 images = (1000/batch_size) iter. +parser.add_argument("--max_resl", type=int, default=8) # 10-->1024, 9-->512, 8-->256 +parser.add_argument("--trns_tick", type=int, default=200) # transition tick +parser.add_argument("--stab_tick", type=int, default=500) # stabilization tick +parser.add_argument("--resume", type=int, default=0) # stabilization tick ## network structure. -parser.add_argument('--flag_wn', type=bool, default=True) # use of equalized-learning rate. -parser.add_argument('--flag_bn', type=bool, default=False) # use of batch-normalization. (not recommended) -parser.add_argument('--flag_pixelwise', type=bool, default=True) # use of pixelwise normalization for generator. -parser.add_argument('--flag_gdrop', type=bool, default=True) # use of generalized dropout layer for discriminator. -parser.add_argument('--flag_leaky', type=bool, default=True) # use of leaky relu instead of relu. -parser.add_argument('--flag_tanh', type=bool, default=False) # use of tanh at the end of the generator. -parser.add_argument('--flag_sigmoid', type=bool, default=False) # use of sigmoid at the end of the discriminator. -parser.add_argument('--flag_add_noise', type=bool, default=True) # add noise to the real image(x) -parser.add_argument('--flag_norm_latent', type=bool, default=False) # pixelwise normalization of latent vector (z) -parser.add_argument('--flag_add_drift', type=bool, default=True) # add drift loss - - +parser.add_argument( + "--flag_wn", type=bool, default=True +) # use of equalized-learning rate. +parser.add_argument( + "--flag_bn", type=bool, default=False +) # use of batch-normalization. (not recommended) +parser.add_argument( + "--flag_pixelwise", type=bool, default=True +) # use of pixelwise normalization for generator. +parser.add_argument( + "--flag_gdrop", type=bool, default=True +) # use of generalized dropout layer for discriminator. +parser.add_argument( + "--flag_leaky", type=bool, default=True +) # use of leaky relu instead of relu. +parser.add_argument( + "--flag_tanh", type=bool, default=False +) # use of tanh at the end of the generator. +parser.add_argument( + "--flag_sigmoid", type=bool, default=False +) # use of sigmoid at the end of the discriminator. +parser.add_argument( + "--flag_add_noise", type=bool, default=True +) # add noise to the real image(x) +parser.add_argument( + "--flag_norm_latent", type=bool, default=False +) # pixelwise normalization of latent vector (z) +parser.add_argument("--flag_add_drift", type=bool, default=True) # add drift loss ## optimizer setting. -parser.add_argument('--optimizer', type=str, default='adam') # optimizer type. -parser.add_argument('--beta1', type=float, default=0.0) # beta1 for adam. -parser.add_argument('--beta2', type=float, default=0.99) # beta2 for adam. +parser.add_argument("--optimizer", type=str, default="adam") # optimizer type. +parser.add_argument("--beta1", type=float, default=0.0) # beta1 for adam. +parser.add_argument("--beta2", type=float, default=0.99) # beta2 for adam. ## display and save setting. -parser.add_argument('--use_tb', type=bool, default=True) # enable tensorboard visualization -parser.add_argument('--save_img_every', type=int, default=20) # save images every specified iteration. -parser.add_argument('--display_tb_every', type=int, default=5) # display progress every specified iteration. +parser.add_argument( + "--use_tb", type=bool, default=True +) # enable tensorboard visualization +parser.add_argument( + "--save_img_every", type=int, default=20 +) # save images every specified iteration. +parser.add_argument( + "--display_tb_every", type=int, default=5 +) # display progress every specified iteration. ## parse and save config. diff --git a/continue.txt b/continue.txt new file mode 100644 index 0000000..c227083 --- /dev/null +++ b/continue.txt @@ -0,0 +1 @@ +0 \ No newline at end of file diff --git a/custom_layers.py b/custom_layers.py index 3aed3fa..145f45b 100644 --- a/custom_layers.py +++ b/custom_layers.py @@ -3,7 +3,7 @@ import torch.nn.functional as F import numpy as np from torch.autograd import Variable -import torch +import torch import torch.nn as nn import torchvision.datasets as dsets import torchvision.transforms as transforms @@ -18,11 +18,12 @@ def __init__(self, layer1, layer2): super(ConcatTable, self).__init__() self.layer1 = layer1 self.layer2 = layer2 - - def forward(self,x): + + def forward(self, x): y = [self.layer1(x), self.layer2(x)] return y + class Flatten(nn.Module): def __init__(self): super(Flatten, self).__init__() @@ -31,7 +32,6 @@ def forward(self, x): return x.view(x.size(0), -1) - class fadein_layer(nn.Module): def __init__(self, config): super(fadein_layer, self).__init__() @@ -43,48 +43,57 @@ def update_alpha(self, delta): # input : [x_low, x_high] from ConcatTable() def forward(self, x): - return torch.add(x[0].mul(1.0-self.alpha), x[1].mul(self.alpha)) - + return torch.add(x[0].mul(1.0 - self.alpha), x[1].mul(self.alpha)) # https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/models/base_model.py class minibatch_std_concat_layer(nn.Module): - def __init__(self, averaging='all'): + def __init__(self, averaging="all"): super(minibatch_std_concat_layer, self).__init__() self.averaging = averaging.lower() - if 'group' in self.averaging: + if "group" in self.averaging: self.n = int(self.averaging[5:]) else: - assert self.averaging in ['all', 'flat', 'spatial', 'none', 'gpool'], 'Invalid averaging mode'%self.averaging - self.adjusted_std = lambda x, **kwargs: torch.sqrt(torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8) + assert self.averaging in ["all", "flat", "spatial", "none", "gpool"], ( + "Invalid averaging mode" % self.averaging + ) + self.adjusted_std = lambda x, **kwargs: torch.sqrt( + torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8 + ) def forward(self, x): shape = list(x.size()) target_shape = copy.deepcopy(shape) vals = self.adjusted_std(x, dim=0, keepdim=True) - if self.averaging == 'all': + if self.averaging == "all": target_shape[1] = 1 vals = torch.mean(vals, dim=1, keepdim=True) - elif self.averaging == 'spatial': + elif self.averaging == "spatial": if len(shape) == 4: - vals = mean(vals, axis=[2,3], keepdim=True) # torch.mean(torch.mean(vals, 2, keepdim=True), 3, keepdim=True) - elif self.averaging == 'none': + vals = mean( + vals, axis=[2, 3], keepdim=True + ) # torch.mean(torch.mean(vals, 2, keepdim=True), 3, keepdim=True) + elif self.averaging == "none": target_shape = [target_shape[0]] + [s for s in target_shape[1:]] - elif self.averaging == 'gpool': + elif self.averaging == "gpool": if len(shape) == 4: - vals = mean(x, [0,2,3], keepdim=True) # torch.mean(torch.mean(torch.mean(x, 2, keepdim=True), 3, keepdim=True), 0, keepdim=True) - elif self.averaging == 'flat': + vals = mean( + x, [0, 2, 3], keepdim=True + ) # torch.mean(torch.mean(torch.mean(x, 2, keepdim=True), 3, keepdim=True), 0, keepdim=True) + elif self.averaging == "flat": target_shape[1] = 1 vals = torch.FloatTensor([self.adjusted_std(x)]) - else: # self.averaging == 'group' + else: # self.averaging == 'group' target_shape[1] = self.n - vals = vals.view(self.n, self.shape[1]/self.n, self.shape[2], self.shape[3]) + vals = vals.view( + self.n, self.shape[1] / self.n, self.shape[2], self.shape[3] + ) vals = mean(vals, axis=0, keepdim=True).view(1, self.n, 1, 1) vals = vals.expand(*target_shape) return torch.cat([x, vals], 1) def __repr__(self): - return self.__class__.__name__ + '(averaging = %s)' % (self.averaging) + return self.__class__.__name__ + "(averaging = %s)" % (self.averaging) class pixelwise_norm_layer(nn.Module): @@ -93,66 +102,75 @@ def __init__(self): self.eps = 1e-8 def forward(self, x): - return x / (torch.mean(x**2, dim=1, keepdim=True) + self.eps) ** 0.5 + return x / (torch.mean(x ** 2, dim=1, keepdim=True) + self.eps) ** 0.5 # for equaliaeed-learning rate. class equalized_conv2d(nn.Module): - def __init__(self, c_in, c_out, k_size, stride, pad, initializer='kaiming', bias=False): + def __init__( + self, c_in, c_out, k_size, stride, pad, initializer="kaiming", bias=False + ): super(equalized_conv2d, self).__init__() self.conv = nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False) - if initializer == 'kaiming': kaiming_normal(self.conv.weight, a=calculate_gain('conv2d')) - elif initializer == 'xavier': xavier_normal(self.conv.weight) - + if initializer == "kaiming": + kaiming_normal(self.conv.weight, a=calculate_gain("conv2d")) + elif initializer == "xavier": + xavier_normal(self.conv.weight) + conv_w = self.conv.weight.data.clone() self.bias = torch.nn.Parameter(torch.FloatTensor(c_out).fill_(0)) - self.scale = (torch.mean(self.conv.weight.data ** 2)) ** 0.5 - self.conv.weight.data.copy_(self.conv.weight.data/self.scale) + self.scale = ((torch.mean(self.conv.weight.data ** 2)) ** 0.5).cpu() + self.conv.weight.data.copy_(self.conv.weight.data / self.scale) def forward(self, x): x = self.conv(x.mul(self.scale)) - return x + self.bias.view(1,-1,1,1).expand_as(x) - - + return x + self.bias.view(1, -1, 1, 1).expand_as(x) + + class equalized_deconv2d(nn.Module): - def __init__(self, c_in, c_out, k_size, stride, pad, initializer='kaiming'): + def __init__(self, c_in, c_out, k_size, stride, pad, initializer="kaiming"): super(equalized_deconv2d, self).__init__() self.deconv = nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False) - if initializer == 'kaiming': kaiming_normal(self.deconv.weight, a=calculate_gain('conv2d')) - elif initializer == 'xavier': xavier_normal(self.deconv.weight) - + if initializer == "kaiming": + kaiming_normal(self.deconv.weight, a=calculate_gain("conv2d")) + elif initializer == "xavier": + xavier_normal(self.deconv.weight) + deconv_w = self.deconv.weight.data.clone() self.bias = torch.nn.Parameter(torch.FloatTensor(c_out).fill_(0)) - self.scale = (torch.mean(self.deconv.weight.data ** 2)) ** 0.5 - self.deconv.weight.data.copy_(self.deconv.weight.data/self.scale) + self.scale = ((torch.mean(self.deconv.weight.data ** 2)) ** 0.5).cpu() + self.deconv.weight.data.copy_(self.deconv.weight.data / self.scale) + def forward(self, x): x = self.deconv(x.mul(self.scale)) - return x + self.bias.view(1,-1,1,1).expand_as(x) + return x + self.bias.view(1, -1, 1, 1).expand_as(x) class equalized_linear(nn.Module): - def __init__(self, c_in, c_out, initializer='kaiming'): + def __init__(self, c_in, c_out, initializer="kaiming"): super(equalized_linear, self).__init__() self.linear = nn.Linear(c_in, c_out, bias=False) - if initializer == 'kaiming': kaiming_normal(self.linear.weight, a=calculate_gain('linear')) - elif initializer == 'xavier': torch.nn.init.xavier_normal(self.linear.weight) - + if initializer == "kaiming": + kaiming_normal(self.linear.weight, a=calculate_gain("linear")) + elif initializer == "xavier": + torch.nn.init.xavier_normal(self.linear.weight) + linear_w = self.linear.weight.data.clone() self.bias = torch.nn.Parameter(torch.FloatTensor(c_out).fill_(0)) - self.scale = (torch.mean(self.linear.weight.data ** 2)) ** 0.5 - self.linear.weight.data.copy_(self.linear.weight.data/self.scale) - + self.scale = ((torch.mean(self.linear.weight.data ** 2)) ** 0.5).cpu() + self.linear.weight.data.copy_(self.linear.weight.data / self.scale) + def forward(self, x): x = self.linear(x.mul(self.scale)) - return x + self.bias.view(1,-1).expand_as(x) + return x + self.bias.view(1, -1).expand_as(x) # ref: https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/models/base_model.py class generalized_drop_out(nn.Module): - def __init__(self, mode='mul', strength=0.4, axes=(0,1), normalize=False): + def __init__(self, mode="mul", strength=0.4, axes=(0, 1), normalize=False): super(generalized_drop_out, self).__init__() self.mode = mode.lower() - assert self.mode in ['mul', 'drop', 'prop'], 'Invalid GDropLayer mode'%mode + assert self.mode in ["mul", "drop", "prop"], "Invalid GDropLayer mode" % mode self.strength = strength self.axes = [axes] if isinstance(axes, int) else list(axes) self.normalize = normalize @@ -162,11 +180,13 @@ def forward(self, x, deterministic=False): if deterministic or not self.strength: return x - rnd_shape = [s if axis in self.axes else 1 for axis, s in enumerate(x.size())] # [x.size(axis) for axis in self.axes] - if self.mode == 'drop': + rnd_shape = [ + s if axis in self.axes else 1 for axis, s in enumerate(x.size()) + ] # [x.size(axis) for axis in self.axes] + if self.mode == "drop": p = 1 - self.strength rnd = np.random.binomial(1, p=p, size=rnd_shape) / p - elif self.mode == 'mul': + elif self.mode == "mul": rnd = (1 + self.strength) ** np.random.normal(size=rnd_shape) else: coef = self.strength * x.size(1) ** 0.5 @@ -180,8 +200,10 @@ def forward(self, x, deterministic=False): return x * rnd def __repr__(self): - param_str = '(mode = %s, strength = %s, axes = %s, normalize = %s)' % (self.mode, self.strength, self.axes, self.normalize) + param_str = "(mode = %s, strength = %s, axes = %s, normalize = %s)" % ( + self.mode, + self.strength, + self.axes, + self.normalize, + ) return self.__class__.__name__ + param_str - - - diff --git a/dataloader.py b/dataloader.py index e922582..7eb52a7 100644 --- a/dataloader.py +++ b/dataloader.py @@ -3,7 +3,8 @@ import numpy as np from io import BytesIO import scipy.misc -#import tensorflow as tf + +# import tensorflow as tf import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader @@ -16,51 +17,56 @@ class dataloader: def __init__(self, config): self.root = config.train_data_root - self.batch_table = {4:32, 8:32, 16:32, 32:16, 64:16, 128:16, 256:12, 512:3, 1024:1} # change this according to available gpu memory. - self.batchsize = int(self.batch_table[pow(2,2)]) # we start from 2^2=4 - self.imsize = int(pow(2,2)) - self.num_workers = 4 - + self.batch_table = { + 4: 32, + 8: 32, + 16: 32, + 32: 16, + 64: 16, + 128: 16, + 256: 12, + 512: 3, + 1024: 1, + } # change this according to available gpu memory. + self.batchsize = int(self.batch_table[pow(2, 2)]) # we start from 2^2=4 + self.imsize = int(pow(2, 2)) + self.num_workers = 10 + def renew(self, resl): - print('[*] Renew dataloader configuration, load data from {}.'.format(self.root)) - - self.batchsize = int(self.batch_table[pow(2,resl)]) - self.imsize = int(pow(2,resl)) + print( + "[*] Renew dataloader configuration, load data from {}.".format(self.root) + ) + + self.batchsize = int(self.batch_table[pow(2, resl)]) + self.imsize = int(pow(2, resl)) self.dataset = ImageFolder( - root=self.root, - transform=transforms.Compose( [ - transforms.Resize(size=(self.imsize,self.imsize), interpolation=Image.NEAREST), - transforms.ToTensor(), - ])) + root=self.root, + transform=transforms.Compose( + [ + transforms.Resize( + size=(self.imsize, self.imsize), interpolation=Image.NEAREST + ), + transforms.ToTensor(), + ] + ), + ) self.dataloader = DataLoader( dataset=self.dataset, batch_size=self.batchsize, shuffle=True, - num_workers=self.num_workers + num_workers=self.num_workers, ) def __iter__(self): return iter(self.dataloader) - + def __next__(self): return next(self.dataloader) def __len__(self): return len(self.dataloader.dataset) - def get_batch(self): dataIter = iter(self.dataloader) - return next(dataIter)[0].mul(2).add(-1) # pixel range [-1, 1] - - - - - - - - - - - + return next(dataIter)[0].mul(2).add(-1) # pixel range [-1, 1] diff --git a/dirty_save_video.py b/dirty_save_video.py new file mode 100644 index 0000000..bb506f3 --- /dev/null +++ b/dirty_save_video.py @@ -0,0 +1,22 @@ +import gizeh +import moviepy.editor as mpy +import cv2 +import os +import numpy as np + +base = '/home/damien/Images/resl_8/' +counter= 0 +images = np.array(os.listdir(base)) +images = images[np.array([int(x[:4])>4379 for x in images])] + + +images = images[np.argsort([int(x[:4]) for x in images])] + +def make_frame(t): + global counter, images + im = cv2.imread(base+images[counter]) + counter += 4 + return im + +clip = mpy.VideoClip(make_frame, duration=int(len(images)/100.)) # 2 seconds +clip.write_gif("canard.gif",fps=25) diff --git a/generate_interpolated.py b/generate_interpolated.py index a927261..e92d98f 100644 --- a/generate_interpolated.py +++ b/generate_interpolated.py @@ -1,7 +1,7 @@ # generate interpolated images. -import os,sys +import os, sys import torch from config import config from torch.autograd import Variable @@ -9,33 +9,34 @@ use_cuda = True -checkpoint_path = 'repo/model/gen_R8_T55.pth.tar' +checkpoint_path = "repo/model/gen_R8_T55.pth.tar" n_intp = 20 # load trained model. import network as net + test_model = net.Generator(config) if use_cuda: - torch.set_default_tensor_type('torch.cuda.FloatTensor') + torch.set_default_tensor_type("torch.cuda.FloatTensor") test_model = torch.nn.DataParallel(test_model).cuda(device=0) else: - torch.set_default_tensor_type('torch.FloatTensor') + torch.set_default_tensor_type("torch.FloatTensor") -for resl in range(3, config.max_resl+1): +for resl in range(3, config.max_resl + 1): test_model.module.grow_network(resl) test_model.module.flush_network() print(test_model) -print('load checkpoint form ... {}'.format(checkpoint_path)) +print("load checkpoint form ... {}".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path) -test_model.module.load_state_dict(checkpoint['state_dict']) +test_model.module.load_state_dict(checkpoint["state_dict"]) # create folder. for i in range(1000): - name = 'repo/interpolation/try_{}'.format(i) + name = "repo/interpolation/try_{}".format(i) if not os.path.exists(name): - os.system('mkdir -p {}'.format(name)) - break; + os.system("mkdir -p {}".format(name)) + break # interpolate between twe noise(z1, z2). z_intp = torch.FloatTensor(1, config.nz) @@ -49,12 +50,10 @@ z_intp = Variable(z_intp) -for i in range(1, n_intp+1): - alpha = 1.0/float(n_intp+1) - z_intp.data = z1.mul_(alpha) + z2.mul_(1.0-alpha) +for i in range(1, n_intp + 1): + alpha = 1.0 / float(n_intp + 1) + z_intp.data = z1.mul_(alpha) + z2.mul_(1.0 - alpha) fake_im = test_model.module(z_intp) - fname = os.path.join(name, '_intp{}.jpg'.format(i)) - utils.save_image_single(fake_im.data, fname, imsize=pow(2,config.max_resl)) - print('saved {}-th interpolated image ...'.format(i)) - - + fname = os.path.join(name, "_intp{}.jpg".format(i)) + utils.save_image_single(fake_im.data, fname, imsize=pow(2, config.max_resl)) + print("saved {}-th interpolated image ...".format(i)) diff --git a/network.py b/network.py index fb3cc56..6d937f2 100644 --- a/network.py +++ b/network.py @@ -8,54 +8,103 @@ # defined for code simplicity. -def deconv(layers, c_in, c_out, k_size, stride=1, pad=0, leaky=True, bn=False, wn=False, pixel=False, only=False): - if wn: layers.append(equalized_conv2d(c_in, c_out, k_size, stride, pad)) - else: layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad)) +def deconv( + layers, + c_in, + c_out, + k_size, + stride=1, + pad=0, + leaky=True, + bn=False, + wn=False, + pixel=False, + only=False, +): + if wn: + layers.append(equalized_conv2d(c_in, c_out, k_size, stride, pad)) + else: + layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad)) if not only: - if leaky: layers.append(nn.LeakyReLU(0.2)) - else: layers.append(nn.ReLU()) - if bn: layers.append(nn.BatchNorm2d(c_out)) - if pixel: layers.append(pixelwise_norm_layer()) + if leaky: + layers.append(nn.LeakyReLU(0.2)) + else: + layers.append(nn.ReLU()) + if bn: + layers.append(nn.BatchNorm2d(c_out)) + if pixel: + layers.append(pixelwise_norm_layer()) return layers -def conv(layers, c_in, c_out, k_size, stride=1, pad=0, leaky=True, bn=False, wn=False, pixel=False, gdrop=True, only=False): - if gdrop: layers.append(generalized_drop_out(mode='prop', strength=0.0)) - if wn: layers.append(equalized_conv2d(c_in, c_out, k_size, stride, pad, initializer='kaiming')) - else: layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad)) + +def conv( + layers, + c_in, + c_out, + k_size, + stride=1, + pad=0, + leaky=True, + bn=False, + wn=False, + pixel=False, + gdrop=True, + only=False, +): + if gdrop: + layers.append(generalized_drop_out(mode="prop", strength=0.0)) + if wn: + layers.append( + equalized_conv2d(c_in, c_out, k_size, stride, pad, initializer="kaiming") + ) + else: + layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad)) if not only: - if leaky: layers.append(nn.LeakyReLU(0.2)) - else: layers.append(nn.ReLU()) - if bn: layers.append(nn.BatchNorm2d(c_out)) - if pixel: layers.append(pixelwise_norm_layer()) + if leaky: + layers.append(nn.LeakyReLU(0.2)) + else: + layers.append(nn.ReLU()) + if bn: + layers.append(nn.BatchNorm2d(c_out)) + if pixel: + layers.append(pixelwise_norm_layer()) return layers + def linear(layers, c_in, c_out, sig=True, wn=False): layers.append(Flatten()) - if wn: layers.append(equalized_linear(c_in, c_out)) - else: layers.append(Linear(c_in, c_out)) - if sig: layers.append(nn.Sigmoid()) + if wn: + layers.append(equalized_linear(c_in, c_out)) + else: + layers.append(Linear(c_in, c_out)) + if sig: + layers.append(nn.Sigmoid()) return layers - + def deepcopy_module(module, target): new_module = nn.Sequential() for name, m in module.named_children(): if name == target: - new_module.add_module(name, m) # make new structure and, - new_module[-1].load_state_dict(m.state_dict()) # copy weights + new_module.add_module(name, m) # make new structure and, + new_module[-1].load_state_dict(m.state_dict()) # copy weights return new_module + def soft_copy_param(target_link, source_link, tau): - ''' soft-copy parameters of a link to another link. ''' + """ soft-copy parameters of a link to another link. """ target_params = dict(target_link.named_parameters()) for param_name, param in source_link.named_parameters(): - target_params[param_name].data = target_params[param_name].data.mul(1.0-tau) - target_params[param_name].data = target_params[param_name].data.add(param.data.mul(tau)) + target_params[param_name].data = target_params[param_name].data.mul(1.0 - tau) + target_params[param_name].data = target_params[param_name].data.add( + param.data.mul(tau) + ) + def get_module_names(model): names = [] - for key, val in model.state_dict().iteritems(): - name = key.split('.')[0] + for key, val in model.state_dict().items(): + name = key.split(".")[0] if not name in names: names.append(name) return names @@ -83,89 +132,192 @@ def first_block(self): ndim = self.ngf if self.flag_norm_latent: layers.append(pixelwise_norm_layer()) - layers = deconv(layers, self.nz, ndim, 4, 1, 3, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise) - layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise) - return nn.Sequential(*layers), ndim + layers = deconv( + layers, + self.nz, + ndim, + 4, + 1, + 3, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + self.flag_pixelwise, + ) + layers = deconv( + layers, + ndim, + ndim, + 3, + 1, + 1, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + self.flag_pixelwise, + ) + return nn.Sequential(*layers), ndim def intermediate_block(self, resl): halving = False - layer_name = 'intermediate_{}x{}_{}x{}'.format(int(pow(2,resl-1)), int(pow(2,resl-1)), int(pow(2, resl)), int(pow(2, resl))) + layer_name = "intermediate_{}x{}_{}x{}".format( + int(pow(2, resl - 1)), + int(pow(2, resl - 1)), + int(pow(2, resl)), + int(pow(2, resl)), + ) ndim = self.ngf - if resl==3 or resl==4 or resl==5: + if resl == 3 or resl == 4 or resl == 5: halving = False ndim = self.ngf - elif resl==6 or resl==7 or resl==8 or resl==9 or resl==10: + elif resl == 6 or resl == 7 or resl == 8 or resl == 9 or resl == 10: halving = True - for i in range(int(resl)-5): - ndim = ndim/2 + for i in range(int(resl) - 5): + ndim = ndim / 2 ndim = int(ndim) layers = [] - layers.append(nn.Upsample(scale_factor=2, mode='nearest')) # scale up by factor of 2.0 + layers.append( + nn.Upsample(scale_factor=2, mode="nearest") + ) # scale up by factor of 2.0 if halving: - layers = deconv(layers, ndim*2, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise) - layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise) + layers = deconv( + layers, + ndim * 2, + ndim, + 3, + 1, + 1, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + self.flag_pixelwise, + ) + layers = deconv( + layers, + ndim, + ndim, + 3, + 1, + 1, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + self.flag_pixelwise, + ) else: - layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise) - layers = deconv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise) - return nn.Sequential(*layers), ndim, layer_name - + layers = deconv( + layers, + ndim, + ndim, + 3, + 1, + 1, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + self.flag_pixelwise, + ) + layers = deconv( + layers, + ndim, + ndim, + 3, + 1, + 1, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + self.flag_pixelwise, + ) + return nn.Sequential(*layers), ndim, layer_name + def to_rgb_block(self, c_in): layers = [] - layers = deconv(layers, c_in, self.nc, 1, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, self.flag_pixelwise, only=True) - if self.flag_tanh: layers.append(nn.Tanh()) + layers = deconv( + layers, + c_in, + self.nc, + 1, + 1, + 0, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + self.flag_pixelwise, + only=True, + ) + if self.flag_tanh: + layers.append(nn.Tanh()) return nn.Sequential(*layers) def get_init_gen(self): model = nn.Sequential() first_block, ndim = self.first_block() - model.add_module('first_block', first_block) - model.add_module('to_rgb_block', self.to_rgb_block(ndim)) + model.add_module("first_block", first_block) + model.add_module("to_rgb_block", self.to_rgb_block(ndim)) self.module_names = get_module_names(model) return model - + def grow_network(self, resl): # we make new network since pytorch does not support remove_module() new_model = nn.Sequential() names = get_module_names(self.model) for name, module in self.model.named_children(): - if not name=='to_rgb_block': - new_model.add_module(name, module) # make new structure and, - new_model[-1].load_state_dict(module.state_dict()) # copy pretrained weights - + if not name == "to_rgb_block": + new_model.add_module(name, module) # make new structure and, + new_model[-1].load_state_dict( + module.state_dict() + ) # copy pretrained weights + if resl >= 3 and resl <= 9: - print('growing network[{}x{} to {}x{}]. It may take few seconds...'.format(int(pow(2,resl-1)), int(pow(2,resl-1)), int(pow(2,resl)), int(pow(2,resl)))) - low_resl_to_rgb = deepcopy_module(self.model, 'to_rgb_block') + # print( + # "growing network[{}x{} to {}x{}]. It may take few seconds...".format( + # int(pow(2, resl - 1)), + # int(pow(2, resl - 1)), + # int(pow(2, resl)), + # int(pow(2, resl)), + # ) + # ) + low_resl_to_rgb = deepcopy_module(self.model, "to_rgb_block") prev_block = nn.Sequential() - prev_block.add_module('low_resl_upsample', nn.Upsample(scale_factor=2, mode='nearest')) - prev_block.add_module('low_resl_to_rgb', low_resl_to_rgb) + prev_block.add_module( + "low_resl_upsample", nn.Upsample(scale_factor=2, mode="nearest") + ) + prev_block.add_module("low_resl_to_rgb", low_resl_to_rgb) inter_block, ndim, self.layer_name = self.intermediate_block(resl) next_block = nn.Sequential() - next_block.add_module('high_resl_block', inter_block) - next_block.add_module('high_resl_to_rgb', self.to_rgb_block(ndim)) + next_block.add_module("high_resl_block", inter_block) + next_block.add_module("high_resl_to_rgb", self.to_rgb_block(ndim)) - new_model.add_module('concat_block', ConcatTable(prev_block, next_block)) - new_model.add_module('fadein_block', fadein_layer(self.config)) + new_model.add_module("concat_block", ConcatTable(prev_block, next_block)) + new_model.add_module("fadein_block", fadein_layer(self.config)) self.model = None self.model = new_model self.module_names = get_module_names(self.model) - + def flush_network(self): try: - print('flushing network... It may take few seconds...') + # print("flushing network... It may take few seconds...") # make deep copy and paste. - high_resl_block = deepcopy_module(self.model.concat_block.layer2, 'high_resl_block') - high_resl_to_rgb = deepcopy_module(self.model.concat_block.layer2, 'high_resl_to_rgb') - + high_resl_block = deepcopy_module( + self.model.concat_block.layer2, "high_resl_block" + ) + high_resl_to_rgb = deepcopy_module( + self.model.concat_block.layer2, "high_resl_to_rgb" + ) + new_model = nn.Sequential() for name, module in self.model.named_children(): - if name!='concat_block' and name!='fadein_block': - new_model.add_module(name, module) # make new structure and, - new_model[-1].load_state_dict(module.state_dict()) # copy pretrained weights + if name != "concat_block" and name != "fadein_block": + new_model.add_module(name, module) # make new structure and, + new_model[-1].load_state_dict( + module.state_dict() + ) # copy pretrained weights # now, add the high resolution block. new_model.add_module(self.layer_name, high_resl_block) - new_model.add_module('to_rgb_block', high_resl_to_rgb) + new_model.add_module("to_rgb_block", high_resl_to_rgb) self.model = new_model self.module_names = get_module_names(self.model) except: @@ -173,7 +325,7 @@ def flush_network(self): def freeze_layers(self): # let's freeze pretrained blocks. (Found freezing layers not helpful, so did not use this func.) - print('freeze pretrained weights ... ') + # print("freeze pretrained weights ... ") for param in self.model.parameters(): param.requires_grad = False @@ -203,115 +355,201 @@ def last_block(self): ndim = self.ndf layers = [] layers.append(minibatch_std_concat_layer()) - layers = conv(layers, ndim+1, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False) - layers = conv(layers, ndim, ndim, 4, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False) + layers = conv( + layers, + ndim + 1, + ndim, + 3, + 1, + 1, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + pixel=False, + ) + layers = conv( + layers, + ndim, + ndim, + 4, + 1, + 0, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + pixel=False, + ) layers = linear(layers, ndim, 1, sig=self.flag_sigmoid, wn=self.flag_wn) - return nn.Sequential(*layers), ndim - + return nn.Sequential(*layers), ndim + def intermediate_block(self, resl): halving = False - layer_name = 'intermediate_{}x{}_{}x{}'.format(int(pow(2,resl)), int(pow(2,resl)), int(pow(2, resl-1)), int(pow(2, resl-1))) + layer_name = "intermediate_{}x{}_{}x{}".format( + int(pow(2, resl)), + int(pow(2, resl)), + int(pow(2, resl - 1)), + int(pow(2, resl - 1)), + ) ndim = self.ndf - if resl==3 or resl==4 or resl==5: + if resl == 3 or resl == 4 or resl == 5: halving = False ndim = self.ndf - elif resl==6 or resl==7 or resl==8 or resl==9 or resl==10: + elif resl == 6 or resl == 7 or resl == 8 or resl == 9 or resl == 10: halving = True - for i in range(int(resl)-5): - ndim = ndim/2 + for i in range(int(resl) - 5): + ndim = ndim / 2 ndim = int(ndim) layers = [] if halving: - layers = conv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False) - layers = conv(layers, ndim, ndim*2, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False) + layers = conv( + layers, + ndim, + ndim, + 3, + 1, + 1, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + pixel=False, + ) + layers = conv( + layers, + ndim, + ndim * 2, + 3, + 1, + 1, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + pixel=False, + ) else: - layers = conv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False) - layers = conv(layers, ndim, ndim, 3, 1, 1, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False) - - layers.append(nn.AvgPool2d(kernel_size=2)) # scale up by factor of 2.0 - return nn.Sequential(*layers), ndim, layer_name - + layers = conv( + layers, + ndim, + ndim, + 3, + 1, + 1, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + pixel=False, + ) + layers = conv( + layers, + ndim, + ndim, + 3, + 1, + 1, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + pixel=False, + ) + + layers.append(nn.AvgPool2d(kernel_size=2)) # scale up by factor of 2.0 + return nn.Sequential(*layers), ndim, layer_name + def from_rgb_block(self, ndim): layers = [] - layers = conv(layers, self.nc, ndim, 1, 1, 0, self.flag_leaky, self.flag_bn, self.flag_wn, pixel=False) - return nn.Sequential(*layers) - + layers = conv( + layers, + self.nc, + ndim, + 1, + 1, + 0, + self.flag_leaky, + self.flag_bn, + self.flag_wn, + pixel=False, + ) + return nn.Sequential(*layers) + def get_init_dis(self): model = nn.Sequential() last_block, ndim = self.last_block() - model.add_module('from_rgb_block', self.from_rgb_block(ndim)) - model.add_module('last_block', last_block) + model.add_module("from_rgb_block", self.from_rgb_block(ndim)) + model.add_module("last_block", last_block) self.module_names = get_module_names(model) return model - def grow_network(self, resl): - + if resl >= 3 and resl <= 9: - print('growing network[{}x{} to {}x{}]. It may take few seconds...'.format(int(pow(2,resl-1)), int(pow(2,resl-1)), int(pow(2,resl)), int(pow(2,resl)))) - low_resl_from_rgb = deepcopy_module(self.model, 'from_rgb_block') + # print( + # "growing network[{}x{} to {}x{}]. It may take few seconds...".format( + # int(pow(2, resl - 1)), + # int(pow(2, resl - 1)), + # int(pow(2, resl)), + # int(pow(2, resl)), + # ) + # ) + low_resl_from_rgb = deepcopy_module(self.model, "from_rgb_block") prev_block = nn.Sequential() - prev_block.add_module('low_resl_downsample', nn.AvgPool2d(kernel_size=2)) - prev_block.add_module('low_resl_from_rgb', low_resl_from_rgb) + prev_block.add_module("low_resl_downsample", nn.AvgPool2d(kernel_size=2)) + prev_block.add_module("low_resl_from_rgb", low_resl_from_rgb) inter_block, ndim, self.layer_name = self.intermediate_block(resl) next_block = nn.Sequential() - next_block.add_module('high_resl_from_rgb', self.from_rgb_block(ndim)) - next_block.add_module('high_resl_block', inter_block) + next_block.add_module("high_resl_from_rgb", self.from_rgb_block(ndim)) + next_block.add_module("high_resl_block", inter_block) new_model = nn.Sequential() - new_model.add_module('concat_block', ConcatTable(prev_block, next_block)) - new_model.add_module('fadein_block', fadein_layer(self.config)) + new_model.add_module("concat_block", ConcatTable(prev_block, next_block)) + new_model.add_module("fadein_block", fadein_layer(self.config)) # we make new network since pytorch does not support remove_module() names = get_module_names(self.model) for name, module in self.model.named_children(): - if not name=='from_rgb_block': - new_model.add_module(name, module) # make new structure and, - new_model[-1].load_state_dict(module.state_dict()) # copy pretrained weights + if not name == "from_rgb_block": + new_model.add_module(name, module) # make new structure and, + new_model[-1].load_state_dict( + module.state_dict() + ) # copy pretrained weights self.model = None self.model = new_model self.module_names = get_module_names(self.model) def flush_network(self): try: - print('flushing network... It may take few seconds...') + # print("flushing network... It may take few seconds...") # make deep copy and paste. - high_resl_block = deepcopy_module(self.model.concat_block.layer2, 'high_resl_block') - high_resl_from_rgb = deepcopy_module(self.model.concat_block.layer2, 'high_resl_from_rgb') - + high_resl_block = deepcopy_module( + self.model.concat_block.layer2, "high_resl_block" + ) + high_resl_from_rgb = deepcopy_module( + self.model.concat_block.layer2, "high_resl_from_rgb" + ) + # add the high resolution block. new_model = nn.Sequential() - new_model.add_module('from_rgb_block', high_resl_from_rgb) + new_model.add_module("from_rgb_block", high_resl_from_rgb) new_model.add_module(self.layer_name, high_resl_block) - + # add rest. for name, module in self.model.named_children(): - if name!='concat_block' and name!='fadein_block': - new_model.add_module(name, module) # make new structure and, - new_model[-1].load_state_dict(module.state_dict()) # copy pretrained weights + if name != "concat_block" and name != "fadein_block": + new_model.add_module(name, module) # make new structure and, + new_model[-1].load_state_dict( + module.state_dict() + ) # copy pretrained weights self.model = new_model self.module_names = get_module_names(self.model) except: self.model = self.model - + def freeze_layers(self): # let's freeze pretrained blocks. (Found freezing layers not helpful, so did not use this func.) - print('freeze pretrained weights ... ') + print("freeze pretrained weights ... ") for param in self.model.parameters(): param.requires_grad = False def forward(self, x): x = self.model(x) return x - - - - - - - - - - diff --git a/tf_recorder.py b/tf_recorder.py index fab772f..ee8c89a 100644 --- a/tf_recorder.py +++ b/tf_recorder.py @@ -11,14 +11,14 @@ class tf_recorder: def __init__(self): - utils.mkdir('repo/tensorboard') - + utils.mkdir("repo/tensorboard") + for i in range(1000): - self.targ = 'repo/tensorboard/try_{}'.format(i) + self.targ = "repo/tensorboard/try_{}".format(i) if not os.path.exists(self.targ): self.writer = SummaryWriter(self.targ) break - + def add_scalar(self, index, val, niter): self.writer.add_scalar(index, val, niter) @@ -33,20 +33,19 @@ def add_image_single(self, index, x, niter): self.writer.add_image(index, x, niter) def add_graph(self, index, x_input, model): - torch.onnx.export(model, x_input, os.path.join(self.targ, "{}.proto".format(index)), verbose=True) + torch.onnx.export( + model, + x_input, + os.path.join(self.targ, "{}.proto".format(index)), + verbose=True, + ) self.writer.add_graph_onnx(os.path.join(self.targ, "{}.proto".format(index))) def export_json(self, out_file): self.writer.export_scalars_to_json(out_file) - - - - - - -''' +""" resnet18 = models.resnet18(False) writer = SummaryWriter() for n_iter in range(100): @@ -66,11 +65,10 @@ def export_json(self, out_file): # export scalar data to JSON for external processing writer.export_scalars_to_json("./all_scalars.json") writer.close() -''' +""" - -''' +""" resnet18 = models.resnet18(False) writer = SummaryWriter() sample_rate = 44100 @@ -105,5 +103,4 @@ def export_json(self, out_file): writer.export_scalars_to_json("./all_scalars.json") writer.close() -''' - +""" diff --git a/trainer.py b/trainer.py index e2fb33d..4d54289 100755 --- a/trainer.py +++ b/trainer.py @@ -11,169 +11,284 @@ import tf_recorder as tensorboard import utils as utils import numpy as np +from multiprocessing import Manager, Value +from torch.autograd import grad as torch_grad + # import tensorflow as tf +def safe_reading(file): + value = file.read() + try: + value = int(value) + return value + except: + return 0 + + +def accelerate(value): + return value * 2 + class trainer: def __init__(self, config): self.config = config if torch.cuda.is_available(): self.use_cuda = True - torch.set_default_tensor_type('torch.cuda.FloatTensor') + torch.set_default_tensor_type("torch.cuda.FloatTensor") else: self.use_cuda = False - torch.set_default_tensor_type('torch.FloatTensor') - + torch.set_default_tensor_type("torch.FloatTensor") + self.nz = config.nz self.optimizer = config.optimizer - self.resl = 2 # we start from 2^2 = 4 + self.resl = 2 # we start from 2^2 = 4 self.lr = config.lr self.eps_drift = config.eps_drift self.smoothing = config.smoothing self.max_resl = config.max_resl + self.accelerate = 1 + self.wgan_target = 1.0 self.trns_tick = config.trns_tick self.stab_tick = config.stab_tick self.TICK = config.TICK + self.skip = False self.globalIter = 0 self.globalTick = 0 + self.wgan_epsilon = 0.001 + self.stack = 0 + self.wgan_lambda = 10.0 + self.just_passed = False + if self.config.resume: + saved_models = os.listdir("repo/model/") + iterations = list( + map(lambda x: int(x.split("_")[-1].split(".")[0][1:]), saved_models) + ) + self.last_iteration = max(iterations) + selected_indexes = np.where([x == self.last_iteration for x in iterations])[ + 0 + ] + G_last_model = [ + saved_models[x] for x in selected_indexes if "gen" in saved_models[x] + ][0] + D_last_model = [ + saved_models[x] for x in selected_indexes if "dis" in saved_models[x] + ][0] + saved_grids = os.listdir("repo/save/grid") + global_iterations = list(map(lambda x: int(x.split("_")[0]), saved_grids)) + self.globalIter = self.config.save_img_every * max(global_iterations) + print( + "Resuming after " + + str(self.last_iteration) + + " ticks and " + + str(self.globalIter) + + " iterations" + ) + G_weights = torch.load("repo/model/" + G_last_model) + D_weights = torch.load("repo/model/" + D_last_model) + self.resuming = True + else: + self.resuming = False + self.kimgs = 0 self.stack = 0 self.epoch = 0 - self.fadein = {'gen':None, 'dis':None} - self.complete = {'gen':0, 'dis':0} - self.phase = 'init' + self.fadein = {"gen": None, "dis": None} + self.complete = {"gen": 0, "dis": 0} + self.phase = "init" self.flag_flush_gen = False self.flag_flush_dis = False self.flag_add_noise = self.config.flag_add_noise self.flag_add_drift = self.config.flag_add_drift - + # network and cirterion self.G = net.Generator(config) self.D = net.Discriminator(config) - print ('Generator structure: ') + print("Generator structure: ") print(self.G.model) - print ('Discriminator structure: ') + print("Discriminator structure: ") print(self.D.model) self.mse = torch.nn.MSELoss() if self.use_cuda: self.mse = self.mse.cuda() torch.cuda.manual_seed(config.random_seed) - if config.n_gpu==1: - self.G = torch.nn.DataParallel(self.G).cuda(device=0) - self.D = torch.nn.DataParallel(self.D).cuda(device=0) - else: - gpus = [] - for i in range(config.n_gpu): - gpus.append(i) - self.G = torch.nn.DataParallel(self.G, device_ids=gpus).cuda() - self.D = torch.nn.DataParallel(self.D, device_ids=gpus).cuda() + self.G = torch.nn.DataParallel(self.G, device_ids=[0]).cuda(device=0) + self.D = torch.nn.DataParallel(self.D, device_ids=[0]).cuda(device=0) - # define tensors, ship model to cuda, and get dataloader. self.renew_everything() - + if self.resuming: + self.resl = G_weights["resl"] + self.globalIter = G_weights["globalIter"] + self.globalTick = G_weights["globalTick"] + self.kimgs = G_weights["kimgs"] + self.epoch = G_weights["epoch"] + self.phase = G_weights["phase"] + self.fadein = G_weights["fadein"] + self.complete = G_weights["complete"] + self.flag_flush_gen = G_weights["flag_flush_gen"] + self.flag_flush_dis = G_weights["flag_flush_dis"] + self.stack = G_weights["stack"] + + print( + "Resuming at " + + str(self.resl) + + " definition after " + + str(self.epoch) + + " epochs" + ) + self.G.module.load_state_dict(G_weights["state_dict"]) + self.D.module.load_state_dict(D_weights["state_dict"]) + self.opt_g.load_state_dict(G_weights["optimizer"]) + self.opt_d.load_state_dict(D_weights["optimizer"]) + # tensorboard self.use_tb = config.use_tb if self.use_tb: self.tb = tensorboard.tf_recorder() - def resl_scheduler(self): - ''' + """ this function will schedule image resolution(self.resl) progressively. it should be called every iteration to ensure resl value is updated properly. step 1. (trns_tick) --> transition in generator. step 2. (stab_tick) --> stabilize. step 3. (trns_tick) --> transition in discriminator. step 4. (stab_tick) --> stabilize. - ''' - if floor(self.resl) != 2 : + """ + + self.previous_phase = self.phase + if self.phase[1:] != "trns": + self.accelerate = 1 + + if floor(self.resl) != 2: self.trns_tick = self.config.trns_tick self.stab_tick = self.config.stab_tick - + self.batchsize = self.loader.batchsize - delta = 1.0/(2*self.trns_tick+2*self.stab_tick) - d_alpha = 1.0*self.batchsize/self.trns_tick/self.TICK + delta = 1.0 / (2 * self.trns_tick + 2 * self.stab_tick) + d_alpha = 1.0 * self.batchsize / self.trns_tick / self.TICK # update alpha if fade-in layer exist. - if self.fadein['gen'] is not None: - if self.resl%1.0 < (self.trns_tick)*delta: - self.fadein['gen'].update_alpha(d_alpha) - self.complete['gen'] = self.fadein['gen'].alpha*100 - self.phase = 'gtrns' - elif self.resl%1.0 >= (self.trns_tick)*delta and self.resl%1.0 < (self.trns_tick+self.stab_tick)*delta: - self.phase = 'gstab' - if self.fadein['dis'] is not None: - if self.resl%1.0 >= (self.trns_tick+self.stab_tick)*delta and self.resl%1.0 < (self.stab_tick + self.trns_tick*2)*delta: - self.fadein['dis'].update_alpha(d_alpha) - self.complete['dis'] = self.fadein['dis'].alpha*100 - self.phase = 'dtrns' - elif self.resl%1.0 >= (self.stab_tick + self.trns_tick*2)*delta and self.phase!='final': - self.phase = 'dstab' - + if self.fadein["gen"] is not None: + if self.resl % 1.0 < (self.trns_tick) * delta: + self.fadein["gen"].update_alpha(d_alpha) + self.complete["gen"] = self.fadein["gen"].alpha * 100 + self.phase = "gtrns" + elif ( + self.resl % 1.0 >= (self.trns_tick) * delta + and self.resl % 1.0 < (self.trns_tick + self.stab_tick) * delta + ): + self.phase = "gstab" + if self.fadein["dis"] is not None: + if ( + self.resl % 1.0 >= (self.trns_tick + self.stab_tick) * delta + and self.resl % 1.0 < (self.stab_tick + self.trns_tick * 2) * delta + ): + self.fadein["dis"].update_alpha(d_alpha) + self.complete["dis"] = self.fadein["dis"].alpha * 100 + self.phase = "dtrns" + elif ( + self.resl % 1.0 >= (self.stab_tick + self.trns_tick * 2) * delta + and self.phase != "final" + ): + self.phase = "dstab" + prev_kimgs = self.kimgs self.kimgs = self.kimgs + self.batchsize - if (self.kimgs%self.TICK) < (prev_kimgs%self.TICK): + if (self.kimgs % self.TICK) < (prev_kimgs % self.TICK): self.globalTick = self.globalTick + 1 + if self.resuming and self.globalTick > self.last_iteration: + self.resuming = False # increase linearly every tick, and grow network structure. prev_resl = floor(self.resl) + f = open("continue.txt", "r") + if safe_reading(f): + f.close() + if self.phase[1:] == "trns": + self.accelerate = accelerate(self.accelerate) + else: + self.skip = True + f = open("continue.txt", "w") + f.write("0") self.resl = self.resl + delta - self.resl = max(2, min(10.5, self.resl)) # clamping, range: 4 ~ 1024 - + f.close() + self.resl = max(2, min(10.5, self.resl)) # clamping, range: 4 ~ 1024 # flush network. - if self.flag_flush_gen and self.resl%1.0 >= (self.trns_tick+self.stab_tick)*delta and prev_resl!=2: - if self.fadein['gen'] is not None: - self.fadein['gen'].update_alpha(d_alpha) - self.complete['gen'] = self.fadein['gen'].alpha*100 + if ( + self.flag_flush_gen + and self.resl % 1.0 >= (self.trns_tick + self.stab_tick) * delta + and prev_resl != 2 + ): + if self.fadein["gen"] is not None: + self.fadein["gen"].update_alpha(d_alpha) + self.complete["gen"] = self.fadein["gen"].alpha * 100 self.flag_flush_gen = False - self.G.module.flush_network() # flush G - print(self.G.module.model) - #self.Gs.module.flush_network() # flush Gs - self.fadein['gen'] = None - self.complete['gen'] = 0.0 - self.phase = 'dtrns' - elif self.flag_flush_dis and floor(self.resl) != prev_resl and prev_resl!=2: - if self.fadein['dis'] is not None: - self.fadein['dis'].update_alpha(d_alpha) - self.complete['dis'] = self.fadein['dis'].alpha*100 + self.G.module.flush_network() # flush G + # print(self.G.module.model) + # self.Gs.module.flush_network() # flush Gs + self.fadein["gen"] = None + self.complete["gen"] = 0.0 + self.phase = "dtrns" + print("flush gen, stop fadein gen, begin phase " + self.phase) + self.just_passed = True + elif ( + self.flag_flush_dis and floor(self.resl) != prev_resl and prev_resl != 2 + ): + if self.fadein["dis"] is not None: + self.fadein["dis"].update_alpha(d_alpha) + self.complete["dis"] = self.fadein["dis"].alpha * 100 self.flag_flush_dis = False - self.D.module.flush_network() # flush and, - print(self.D.module.model) - self.fadein['dis'] = None - self.complete['dis'] = 0.0 - if floor(self.resl) < self.max_resl and self.phase != 'final': - self.phase = 'gtrns' + self.D.module.flush_network() # flush and, + # print(self.D.module.model) + self.fadein["dis"] = None + self.complete["dis"] = 0.0 + if floor(self.resl) < self.max_resl and self.phase != "final": + self.phase = "gtrns" + print("flush dis, stop fadein dis, begin phase " + self.phase) + self.just_passed = True # grow network. - if floor(self.resl) != prev_resl and floor(self.resl)= self.max_resl and self.resl%1.0 >= (self.stab_tick + self.trns_tick*2)*delta: - self.phase = 'final' - self.resl = self.max_resl + (self.stab_tick + self.trns_tick*2)*delta + if ( + floor(self.resl) >= self.max_resl + and self.resl % 1.0 >= (self.stab_tick + self.trns_tick * 2) * delta + ): + self.phase = "final" + self.resl = ( + self.max_resl + (self.stab_tick + self.trns_tick * 2) * delta + ) - - def renew_everything(self): # renew dataloader. self.loader = DL.dataloader(config) self.loader.renew(min(floor(self.resl), self.max_resl)) - + # define tensors self.z = torch.FloatTensor(self.loader.batchsize, self.nz) - self.x = torch.FloatTensor(self.loader.batchsize, 3, self.loader.imsize, self.loader.imsize) - self.x_tilde = torch.FloatTensor(self.loader.batchsize, 3, self.loader.imsize, self.loader.imsize) + self.x = torch.FloatTensor( + self.loader.batchsize, 3, self.loader.imsize, self.loader.imsize + ) + self.x_tilde = torch.FloatTensor( + self.loader.batchsize, 3, self.loader.imsize, self.loader.imsize + ) self.real_label = torch.FloatTensor(self.loader.batchsize).fill_(1) self.fake_label = torch.FloatTensor(self.loader.batchsize).fill_(0) - + # enable cuda if self.use_cuda: self.z = self.z.cuda() @@ -184,36 +299,56 @@ def renew_everything(self): torch.cuda.manual_seed(config.random_seed) # wrapping autograd Variable. - self.x = Variable(self.x) + self.x = Variable(self.x, requires_grad=True) self.x_tilde = Variable(self.x_tilde) self.z = Variable(self.z) self.real_label = Variable(self.real_label) self.fake_label = Variable(self.fake_label) - + # ship new model to cuda. if self.use_cuda: self.G = self.G.cuda() self.D = self.D.cuda() - + # optimizer betas = (self.config.beta1, self.config.beta2) - if self.optimizer == 'adam': - self.opt_g = Adam(filter(lambda p: p.requires_grad, self.G.parameters()), lr=self.lr, betas=betas, weight_decay=0.0) - self.opt_d = Adam(filter(lambda p: p.requires_grad, self.D.parameters()), lr=self.lr, betas=betas, weight_decay=0.0) - + if self.optimizer == "adam": + self.opt_g = Adam( + filter(lambda p: p.requires_grad, self.G.parameters()), + lr=self.lr, + betas=betas, + weight_decay=0.0, + ) + self.opt_d = Adam( + filter(lambda p: p.requires_grad, self.D.parameters()), + lr=self.lr, + betas=betas, + weight_decay=0.0, + ) def feed_interpolated_input(self, x): - if self.phase == 'gtrns' and floor(self.resl)>2 and floor(self.resl)<=self.max_resl: - alpha = self.complete['gen']/100.0 - transform = transforms.Compose( [ transforms.ToPILImage(), - transforms.Scale(size=int(pow(2,floor(self.resl)-1)), interpolation=0), # 0: nearest - transforms.Scale(size=int(pow(2,floor(self.resl))), interpolation=0), # 0: nearest - transforms.ToTensor(), - ] ) + if ( + self.phase == "gtrns" + and floor(self.resl) > 2 + and floor(self.resl) <= self.max_resl + ): + alpha = self.complete["gen"] / 100.0 + transform = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.Scale( + size=int(pow(2, floor(self.resl) - 1)), interpolation=0 + ), # 0: nearest + transforms.Scale( + size=int(pow(2, floor(self.resl))), interpolation=0 + ), # 0: nearest + transforms.ToTensor(), + ] + ) x_low = x.clone().add(1).mul(0.5) for i in range(x_low.size(0)): x_low[i] = transform(x_low[i]).mul(2).add(-1) - x = torch.add(x.mul(alpha), x_low.mul(1-alpha)) # interpolated_x + x = torch.add(x.mul(alpha), x_low.mul(1 - alpha)) # interpolated_x if self.use_cuda: return x.cuda() @@ -222,37 +357,65 @@ def feed_interpolated_input(self, x): def add_noise(self, x): # TODO: support more method of adding noise. - if self.flag_add_noise==False: + if self.flag_add_noise == False: return x - if hasattr(self, '_d_'): + if hasattr(self, "_d_"): self._d_ = self._d_ * 0.9 + torch.mean(self.fx_tilde).item() * 0.1 else: self._d_ = 0.0 - strength = 0.2 * max(0, self._d_ - 0.5)**2 + strength = 0.2 * max(0, self._d_ - 0.5) ** 2 z = np.random.randn(*x.size()).astype(np.float32) * strength - z = Variable(torch.from_numpy(z)).cuda() if self.use_cuda else Variable(torch.from_numpy(z)) + z = ( + Variable(torch.from_numpy(z)).cuda() + if self.use_cuda + else Variable(torch.from_numpy(z)) + ) return x + z + def _gradient_penalty(self, gradients): + # Gradients have shape (batch_size, num_channels, img_width, img_height), + # so flatten to easily take norm per example in batch + gradients = gradients.view(self.batchsize, -1) + # Derivatives of the gradient close to 0 can cause problems because of + # the square root, so manually calculate norm and add epsilon + gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) + + # Return gradient penalty + return self.wgan_lambda * ((gradients_norm - 1) ** 2).mean() + def train(self): # noise for test. self.z_test = torch.FloatTensor(self.loader.batchsize, self.nz) if self.use_cuda: self.z_test = self.z_test.cuda() - self.z_test = Variable(self.z_test, volatile=True) + self.z_test.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0) - - for step in range(2, self.max_resl+1+5): - for iter in tqdm(range(0,(self.trns_tick*2+self.stab_tick*2)*self.TICK, self.loader.batchsize)): - self.globalIter = self.globalIter+1 + + for step in range(2, self.max_resl + 1 + 5): + for iter in tqdm( + range( + 0, + (self.trns_tick * 2 + self.stab_tick * 2) * self.TICK, + self.loader.batchsize, + ) + ): + if self.just_passed: + continue + self.globalIter = self.globalIter + 1 self.stack = self.stack + self.loader.batchsize if self.stack > ceil(len(self.loader.dataset)): self.epoch = self.epoch + 1 - self.stack = int(self.stack%(ceil(len(self.loader.dataset)))) + self.stack = int(self.stack % (ceil(len(self.loader.dataset)))) # reslolution scheduler. self.resl_scheduler() - + if self.skip and self.previous_phase == self.phase: + continue + self.skip = False + if self.globalIter % self.accelerate != 0: + continue + # zero gradients. self.G.zero_grad() self.D.zero_grad() @@ -263,12 +426,30 @@ def train(self): self.x = self.add_noise(self.x) self.z.data.resize_(self.loader.batchsize, self.nz).normal_(0.0, 1.0) self.x_tilde = self.G(self.z) - + self.fx = self.D(self.x) self.fx_tilde = self.D(self.x_tilde.detach()) - - loss_d = self.mse(self.fx.squeeze(), self.real_label) + \ - self.mse(self.fx_tilde, self.fake_label) + + loss_d = self.mse(self.fx.squeeze(), self.real_label) + self.mse( + self.fx_tilde, self.fake_label + ) + + ### gradient penalty + gradients = torch_grad( + outputs=self.fx, + inputs=self.x, + grad_outputs=torch.ones(self.fx.size()).cuda() + if self.use_cuda + else torch.ones(self.fx.size()), + create_graph=True, + retain_graph=True, + )[0] + gradient_penalty = self._gradient_penalty(gradients) + loss_d += gradient_penalty + + ### epsilon penalty + epsilon_penalty = (self.fx ** 2).mean() + loss_d += epsilon_penalty * self.wgan_epsilon loss_d.backward() self.opt_d.step() @@ -277,97 +458,169 @@ def train(self): loss_g = self.mse(fx_tilde.squeeze(), self.real_label.detach()) loss_g.backward() self.opt_g.step() - + # logging. - log_msg = ' [E:{0}][T:{1}][{2:6}/{3:6}] errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]'.format(self.epoch, self.globalTick, self.stack, len(self.loader.dataset), loss_d.item(), loss_g.item(), self.resl, int(pow(2,floor(self.resl))), self.phase, self.complete['gen'], self.complete['dis'], self.lr) - tqdm.write(log_msg) + if (iter - 1) % 10: + log_msg = " [E:{0}][T:{1}][{2:6}/{3:6}] errD: {4:.4f} | errG: {5:.4f} | [lr:{11:.5f}][cur:{6:.3f}][resl:{7:4}][{8}][{9:.1f}%][{10:.1f}%]".format( + self.epoch, + self.globalTick, + self.stack, + len(self.loader.dataset), + loss_d.item(), + loss_g.item(), + self.resl, + int(pow(2, floor(self.resl))), + self.phase, + self.complete["gen"], + self.complete["dis"], + self.lr, + ) + tqdm.write(log_msg) # save model. - self.snapshot('repo/model') + self.snapshot("repo/model") # save image grid. - if self.globalIter%self.config.save_img_every == 0: + if self.globalIter % self.config.save_img_every == 0: with torch.no_grad(): x_test = self.G(self.z_test) - utils.mkdir('repo/save/grid') - utils.save_image_grid(x_test.data, 'repo/save/grid/{}_{}_G{}_D{}.jpg'.format(int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis'])) - utils.mkdir('repo/save/resl_{}'.format(int(floor(self.resl)))) - utils.save_image_single(x_test.data, 'repo/save/resl_{}/{}_{}_G{}_D{}.jpg'.format(int(floor(self.resl)),int(self.globalIter/self.config.save_img_every), self.phase, self.complete['gen'], self.complete['dis'])) + utils.mkdir("repo/save/grid") + utils.mkdir("repo/save/grid_real") + utils.save_image_grid( + x_test.data, + "repo/save/grid/{}_{}_G{}_D{}.jpg".format( + int(self.globalIter / self.config.save_img_every), + self.phase, + self.complete["gen"], + self.complete["dis"], + ), + ) + if self.globalIter % self.config.save_img_every * 10 == 0: + utils.save_image_grid( + self.x.data, + "repo/save/grid_real/{}_{}_G{}_D{}.jpg".format( + int(self.globalIter / self.config.save_img_every), + self.phase, + self.complete["gen"], + self.complete["dis"], + ), + ) + utils.mkdir("repo/save/resl_{}".format(int(floor(self.resl)))) + utils.mkdir("repo/save/resl_{}_real".format(int(floor(self.resl)))) + utils.save_image_single( + x_test.data, + "repo/save/resl_{}/{}_{}_G{}_D{}.jpg".format( + int(floor(self.resl)), + int(self.globalIter / self.config.save_img_every), + self.phase, + self.complete["gen"], + self.complete["dis"], + ), + ) + if self.globalIter % self.config.save_img_every * 10 == 0: + utils.save_image_single( + self.x.data, + "repo/save/resl_{}_real/{}_{}_G{}_D{}.jpg".format( + int(floor(self.resl)), + int(self.globalIter / self.config.save_img_every), + self.phase, + self.complete["gen"], + self.complete["dis"], + ), + ) # tensorboard visualization. if self.use_tb: with torch.no_grad(): x_test = self.G(self.z_test) - self.tb.add_scalar('data/loss_g', loss_g[0].item(), self.globalIter) - self.tb.add_scalar('data/loss_d', loss_d[0].item(), self.globalIter) - self.tb.add_scalar('tick/lr', self.lr, self.globalIter) - self.tb.add_scalar('tick/cur_resl', int(pow(2,floor(self.resl))), self.globalIter) - '''IMAGE GRID + self.tb.add_scalar("data/loss_g", loss_g.item(), self.globalIter) + self.tb.add_scalar("data/loss_d", loss_d.item(), self.globalIter) + self.tb.add_scalar("tick/lr", self.lr, self.globalIter) + self.tb.add_scalar( + "tick/cur_resl", int(pow(2, floor(self.resl))), self.globalIter + ) + """IMAGE GRID self.tb.add_image_grid('grid/x_test', 4, utils.adjust_dyn_range(x_test.data.float(), [-1,1], [0,1]), self.globalIter) self.tb.add_image_grid('grid/x_tilde', 4, utils.adjust_dyn_range(self.x_tilde.data.float(), [-1,1], [0,1]), self.globalIter) self.tb.add_image_grid('grid/x_intp', 4, utils.adjust_dyn_range(self.x.data.float(), [-1,1], [0,1]), self.globalIter) - ''' + """ + self.just_passed = False def get_state(self, target): - if target == 'gen': + if target == "gen": state = { - 'resl' : self.resl, - 'state_dict' : self.G.module.state_dict(), - 'optimizer' : self.opt_g.state_dict(), + "resl": self.resl, + "state_dict": self.G.module.state_dict(), + "optimizer": self.opt_g.state_dict(), } return state - elif target == 'dis': + elif target == "dis": state = { - 'resl' : self.resl, - 'state_dict' : self.D.module.state_dict(), - 'optimizer' : self.opt_d.state_dict(), + "resl": self.resl, + "state_dict": self.D.module.state_dict(), + "optimizer": self.opt_d.state_dict(), } return state - def get_state(self, target): - if target == 'gen': + if target == "gen": state = { - 'resl' : self.resl, - 'state_dict' : self.G.module.state_dict(), - 'optimizer' : self.opt_g.state_dict(), + "resl": self.resl, + "state_dict": self.G.module.state_dict(), + "optimizer": self.opt_g.state_dict(), + "globalIter": self.globalIter, + "globalTick": self.globalTick, + "phase": self.phase, + "epoch": self.epoch, + "kimgs": self.kimgs, + "fadein": self.fadein, + "complete": self.complete, + "flag_flush_gen": self.flag_flush_gen, + "flag_flush_dis": self.flag_flush_dis, } return state - elif target == 'dis': + elif target == "dis": state = { - 'resl' : self.resl, - 'state_dict' : self.D.module.state_dict(), - 'optimizer' : self.opt_d.state_dict(), + "resl": self.resl, + "state_dict": self.D.module.state_dict(), + "optimizer": self.opt_d.state_dict(), + "globalIter": self.globalIter, + "globalTick": self.globalTick, + "phase": self.phase, + "epoch": self.epoch, + "kimgs": self.kimgs, + "fadein": self.fadein, + "complete": self.complete, + "flag_flush_gen": self.flag_flush_gen, + "flag_flush_dis": self.flag_flush_dis, } return state - def snapshot(self, path): if not os.path.exists(path): - if os.name == 'nt': - os.system('mkdir {}'.format(path.replace('/', '\\'))) + if os.name == "nt": + os.system("mkdir {}".format(path.replace("/", "\\"))) else: - os.system('mkdir -p {}'.format(path)) + os.system("mkdir -p {}".format(path)) # save every 100 tick if the network is in stab phase. - ndis = 'dis_R{}_T{}.pth.tar'.format(int(floor(self.resl)), self.globalTick) - ngen = 'gen_R{}_T{}.pth.tar'.format(int(floor(self.resl)), self.globalTick) - if self.globalTick%50==0: - if self.phase == 'gstab' or self.phase =='dstab' or self.phase == 'final': + ndis = "dis_R{}_T{}.pth.tar".format(int(floor(self.resl)), self.globalTick) + ngen = "gen_R{}_T{}.pth.tar".format(int(floor(self.resl)), self.globalTick) + if self.globalTick % 50 == 0: + if self.phase == "gstab" or self.phase == "dstab" or self.phase == "final": save_path = os.path.join(path, ndis) if not os.path.exists(save_path): - torch.save(self.get_state('dis'), save_path) + torch.save(self.get_state("dis"), save_path) save_path = os.path.join(path, ngen) - torch.save(self.get_state('gen'), save_path) - print('[snapshot] model saved @ {}'.format(path)) + torch.save(self.get_state("gen"), save_path) + print("[snapshot] model saved @ {}".format(path)) + -if __name__ == '__main__': +if __name__ == "__main__": ## perform training. - print('----------------- configuration -----------------') + print("----------------- configuration -----------------") for k, v in vars(config).items(): - print(' {}: {}'.format(k, v)) - print('-------------------------------------------------') - torch.backends.cudnn.benchmark = True # boost speed. + print(" {}: {}".format(k, v)) + print("-------------------------------------------------") + torch.backends.cudnn.benchmark = True # boost speed. trainer = trainer(config) trainer.train() - - diff --git a/utils.py b/utils.py index 15ddc00..4ce258a 100644 --- a/utils.py +++ b/utils.py @@ -12,64 +12,70 @@ def adjust_dyn_range(x, drange_in, drange_out): if not drange_in == drange_out: - scale = float(drange_out[1]-drange_out[0])/float(drange_in[1]-drange_in[0]) - bias = drange_out[0]-drange_in[0]*scale + scale = float(drange_out[1] - drange_out[0]) / float( + drange_in[1] - drange_in[0] + ) + bias = drange_out[0] - drange_in[0] * scale x = x.mul(scale).add(bias) return x def resize(x, size): - transform = transforms.Compose([ - transforms.ToPILImage(), - transforms.Scale(size), - transforms.ToTensor(), - ]) + transform = transforms.Compose( + [transforms.ToPILImage(), transforms.Scale(size), transforms.ToTensor()] + ) return transform(x) def make_image_grid(x, ngrid): x = x.clone().cpu() - if pow(ngrid,2) < x.size(0): - grid = make_grid(x[:ngrid*ngrid], nrow=ngrid, padding=0, normalize=True, scale_each=False) + if pow(ngrid, 2) < x.size(0): + grid = make_grid( + x[: ngrid * ngrid], nrow=ngrid, padding=0, normalize=True, scale_each=False + ) else: - grid = torch.FloatTensor(ngrid*ngrid, x.size(1), x.size(2), x.size(3)).fill_(1) - grid[:x.size(0)].copy_(x) + grid = torch.FloatTensor(ngrid * ngrid, x.size(1), x.size(2), x.size(3)).fill_( + 1 + ) + grid[: x.size(0)].copy_(x) grid = make_grid(grid, nrow=ngrid, padding=0, normalize=True, scale_each=False) return grid def save_image_single(x, path, imsize=512): from PIL import Image + grid = make_image_grid(x, 1) ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() im = Image.fromarray(ndarr) - im = im.resize((imsize,imsize), Image.NEAREST) + im = im.resize((imsize, imsize), Image.NEAREST) im.save(path) def save_image_grid(x, path, imsize=512, ngrid=4): from PIL import Image + grid = make_image_grid(x, ngrid) ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() im = Image.fromarray(ndarr) - im = im.resize((imsize,imsize), Image.NEAREST) + im = im.resize((imsize, imsize), Image.NEAREST) im.save(path) - def load_model(net, path): net.load_state_dict(torch.load(path)) + def save_model(net, path): torch.save(net.state_dict(), path) def make_summary(writer, key, value, step): - if hasattr(value, '__len__'): + if hasattr(value, "__len__"): for idx, img in enumerate(value): summary = tf.Summary() sio = BytesIO() - scipy.misc.toimage(img).save(sio, format='png') + scipy.misc.toimage(img).save(sio, format="png") image_summary = tf.Summary.Image(encoded_image_string=sio.getvalue()) summary.value.add(tag="{}/{}".format(key, idx), image=image_summary) writer.add_summary(summary, global_step=step) @@ -79,17 +85,27 @@ def make_summary(writer, key, value, step): def mkdir(path): - if os.name == 'nt': - os.system('mkdir {}'.format(path.replace('/', '\\'))) + if os.name == "nt": + os.system("mkdir {}".format(path.replace("/", "\\"))) else: - os.system('mkdir -r {}'.format(path)) + os.system("mkdir -p {}".format(path)) import torch import math + irange = range -def make_grid(tensor, nrow=8, padding=2, - normalize=False, range=None, scale_each=False, pad_value=0): + + +def make_grid( + tensor, + nrow=8, + padding=2, + normalize=False, + range=None, + scale_each=False, + pad_value=0, +): """Make a grid of images. Args: tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) @@ -108,9 +124,13 @@ def make_grid(tensor, nrow=8, padding=2, Example: See this notebook `here `_ """ - if not (torch.is_tensor(tensor) or - (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): - raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor))) + if not ( + torch.is_tensor(tensor) + or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor)) + ): + raise TypeError( + "tensor or list of tensors expected, got {}".format(type(tensor)) + ) # if list of tensors, convert to a 4D mini-batch Tensor if isinstance(tensor, list): @@ -128,8 +148,9 @@ def make_grid(tensor, nrow=8, padding=2, if normalize is True: tensor = tensor.clone() # avoid modifying tensor in-place if range is not None: - assert isinstance(range, tuple), \ - "range has to be a tuple (min, max) if specified. min and max are numbers" + assert isinstance( + range, tuple + ), "range has to be a tuple (min, max) if specified. min and max are numbers" def norm_ip(img, min, max): img.clamp_(min=min, max=max) @@ -152,21 +173,31 @@ def norm_range(t, range): xmaps = min(nrow, nmaps) ymaps = int(math.ceil(float(nmaps) / xmaps)) height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) - grid = tensor.new(3, height * ymaps + padding, width * xmaps + padding).fill_(pad_value) + grid = tensor.new(3, height * ymaps + padding, width * xmaps + padding).fill_( + pad_value + ) k = 0 for y in irange(ymaps): for x in irange(xmaps): if k >= nmaps: break - grid.narrow(1, y * height + padding, height - padding)\ - .narrow(2, x * width + padding, width - padding)\ - .copy_(tensor[k]) + grid.narrow(1, y * height + padding, height - padding).narrow( + 2, x * width + padding, width - padding + ).copy_(tensor[k]) k = k + 1 return grid -def save_image(tensor, filename, nrow=8, padding=2, - normalize=False, range=None, scale_each=False, pad_value=0): +def save_image( + tensor, + filename, + nrow=8, + padding=2, + normalize=False, + range=None, + scale_each=False, + pad_value=0, +): """Save a given Tensor into an image file. Args: tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, @@ -174,9 +205,17 @@ def save_image(tensor, filename, nrow=8, padding=2, **kwargs: Other arguments are documented in ``make_grid``. """ from PIL import Image + tensor = tensor.cpu() - grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, - normalize=normalize, range=range, scale_each=scale_each) + grid = make_grid( + tensor, + nrow=nrow, + padding=padding, + pad_value=pad_value, + normalize=normalize, + range=range, + scale_each=scale_each, + ) ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() im = Image.fromarray(ndarr) im.save(filename)