Skip to content

Commit 9c944e9

Browse files
committed
conv vae
1 parent fa6e409 commit 9c944e9

File tree

4 files changed

+277
-100
lines changed

4 files changed

+277
-100
lines changed

AE.py

Lines changed: 92 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,95 @@
1-
import torch.nn as nn
1+
import torch
2+
import torchvision.datasets as dsets
3+
import torchvision.transforms as transforms
4+
import torchvision
5+
from torch.autograd import Variable
6+
7+
from time import time
8+
9+
from AE import *
10+
11+
12+
# Training hyper-parameters (hidden_size is the autoencoder's h_dim).
num_epochs, batch_size, hidden_size = 50, 100, 30

# MNIST training split; ToTensor() is the only transform applied and the
# data is downloaded into ../data when it is not already there.
dataset = dsets.MNIST(
    root='../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled mini-batches of `batch_size` images.
data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=batch_size,
                                          shuffle=True)
27+
28+
def to_var(x):
    """Wrap tensor ``x`` in a ``Variable``.

    The tensor is moved to the GPU first whenever CUDA is available.
    """
    tensor = x.cuda() if torch.cuda.is_available() else x
    return Variable(tensor)
232

333

434
class Autoencoder(nn.Module):
    """Fully connected autoencoder with a single hidden layer.

    Args:
        in_dim: size of each flattened input vector (default 784).
        h_dim: size of the hidden code (default 400).
    """

    def __init__(self, in_dim=784, h_dim=400):
        super(Autoencoder, self).__init__()

        # Linear projection to the hidden code followed by a ReLU.
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.ReLU()
        )

        # Linear projection back to in_dim; the sigmoid keeps every output
        # in [0, 1].
        self.decoder = nn.Sequential(
            nn.Linear(h_dim, in_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        Encode ``x``, decode the code, and return the reconstruction.

        Note: image dimension conversion will be handled by external methods
        """
        out = self.encoder(x)
        out = self.decoder(out)
        return out
56+
57+
58+
ae = Autoencoder(in_dim=784, h_dim=hidden_size)

if torch.cuda.is_available():
    ae.cuda()

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(ae.parameters(), lr=0.001)
# Number of mini-batches per epoch; len(data_loader) also counts a final
# partial batch, which len(dataset)//batch_size would drop.
iter_per_epoch = len(data_loader)
data_iter = iter(data_loader)

# save fixed inputs for debugging
fixed_x, _ = next(data_iter)
# NOTE(review): images are written to ./data while the dataset root is
# ../data -- confirm that ./data exists before running.
torchvision.utils.save_image(Variable(fixed_x).data.cpu(), './data/real_images.png')
fixed_x = to_var(fixed_x.view(fixed_x.size(0), -1))

for epoch in range(num_epochs):
    t0 = time()
    for i, (images, _) in enumerate(data_loader):

        # flatten the image
        images = to_var(images.view(images.size(0), -1))
        out = ae(images)
        # The (flattened) input is its own reconstruction target.
        loss = criterion(out, images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            # NOTE(review): loss.data[0] is the pre-0.4 PyTorch way of
            # reading a scalar loss; newer releases need loss.item().
            print ('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f Time: %.2fs'
                   %(epoch+1, num_epochs, i+1, iter_per_epoch, loss.data[0], time()-t0))

    # save the reconstructed images (one file per epoch, same fixed batch)
    reconst_images = ae(fixed_x)
    reconst_images = reconst_images.view(reconst_images.size(0), 1, 28, 28)
    torchvision.utils.save_image(reconst_images.data.cpu(), './data/reconst_images_%d.png' % (epoch+1))
94+
95+

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
# Autoencoder in PyTorch #
1+
# Autoencoders in PyTorch #
22

3-
### Update - Jun 30, 2017 ###
3+
### Update - Feb 4, 2018 ###
44

55
* One layer vanilla autoencoder on MNIST
6+
* Variational autoencoder with convolutional hidden layers on CIFAR-10

conv_vae.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
from __future__ import print_function
2+
import argparse
3+
import torch
4+
import torch.utils.data
5+
from torch import nn, optim
6+
from torch.autograd import Variable
7+
import torch.nn as nn
8+
from torch.nn import functional as F
9+
from torchvision import datasets, transforms
10+
from torchvision.utils import save_image
11+
12+
13+
# Command-line options. The description names CIFAR-10 because that is the
# dataset this script loads.
parser = argparse.ArgumentParser(description='VAE CIFAR-10 Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
# Passing --no-cuda turns CUDA off, so the help text says "disables".
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--hidden-size', type=int, default=20, metavar='N',
                    help='how big is z')
parser.add_argument('--intermediate-size', type=int, default=128, metavar='N',
                    help='how big is linear around z')
# parser.add_argument('--widen-factor', type=int, default=1, metavar='N',
#                     help='how wide is the model')
args = parser.parse_args()
# CUDA is used only when it was not disabled and a device is available.
args.cuda = not args.no_cuda and torch.cuda.is_available()
32+
33+
34+
# Seed the CPU generator always and the CUDA generator when CUDA is in use;
# the extra DataLoader options are likewise only set for CUDA runs.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# CIFAR-10 training split (downloaded on demand), shuffled.
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data',
                     train=True,
                     download=True,
                     transform=transforms.ToTensor()),
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs)

# CIFAR-10 test split, iterated in a fixed order.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data',
                     train=False,
                     transform=transforms.ToTensor()),
    batch_size=args.batch_size,
    shuffle=False,
    **kwargs)
47+
48+
49+
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x32x32 images.

    The encoder halves the spatial size once (stride-2 ``conv2``), giving a
    32x16x16 feature map that is flattened for the linear layers; the
    decoder mirrors this and upsamples back with ``deconv3``.

    Args:
        hidden_size: dimension of the latent vector z. Defaults to the
            module-level ``args.hidden_size``.
        intermediate_size: width of the linear layers around z. Defaults to
            the module-level ``args.intermediate_size``.
    """

    def __init__(self, hidden_size=None, intermediate_size=None):
        super(VAE, self).__init__()

        # Fall back to the command-line values only when no explicit size is
        # given, so the class can also be built without the global ``args``.
        if hidden_size is None:
            hidden_size = args.hidden_size
        if intermediate_size is None:
            intermediate_size = args.intermediate_size

        # (channels, height, width) of the map between conv and linear parts.
        self._feat_shape = (32, 16, 16)
        flat_features = 32 * 16 * 16

        # Encoder
        self.conv1 = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(3, 32, kernel_size=2, stride=2, padding=0)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(flat_features, intermediate_size)

        # Latent space: fc21 produces mu, fc22 produces log-variance.
        self.fc21 = nn.Linear(intermediate_size, hidden_size)
        self.fc22 = nn.Linear(intermediate_size, hidden_size)

        # Decoder
        self.fc3 = nn.Linear(hidden_size, intermediate_size)
        self.fc4 = nn.Linear(intermediate_size, flat_features)
        self.deconv1 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.deconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2, padding=0)
        self.conv5 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.relu(self.conv3(out))
        out = self.relu(self.conv4(out))
        # Flatten everything except the batch dimension.
        out = out.view(out.size(0), -1)
        h1 = self.relu(self.fc1(out))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std while training; return ``mu`` in eval mode."""
        if self.training:
            # std = exp(0.5 * logvar)
            std = logvar.mul(0.5).exp_()
            # Standard-normal noise with the same type and size as std.
            eps = Variable(std.data.new(std.size()).normal_())
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        """Map latent vectors ``z`` to images with values in [0, 1]."""
        h3 = self.relu(self.fc3(z))
        out = self.relu(self.fc4(h3))
        # Restore the (channels, height, width) layout for the deconvolutions.
        out = out.view(out.size(0), *self._feat_shape)
        out = self.relu(self.deconv1(out))
        out = self.relu(self.deconv2(out))
        out = self.relu(self.deconv3(out))
        out = self.sigmoid(self.conv5(out))
        return out

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
107+
108+
109+
# Build the network (on the GPU when CUDA is enabled) and its optimizer.
# Module.cuda() returns the module itself, so the result can be bound directly.
model = VAE().cuda() if args.cuda else VAE()
optimizer = optim.RMSprop(model.parameters(), lr=0.001)
113+
114+
115+
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the summed VAE loss: binary cross-entropy plus KL divergence.

    Args:
        recon_x: reconstructed batch with values in [0, 1].
        x: target batch holding the same number of elements as ``recon_x``.
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        A scalar: the reconstruction error summed over every element, plus
        the KL divergence summed over every latent unit and sample.
    """
    # Flatten per sample instead of assuming 32 * 32 * 3 values; the summed
    # loss is the same and inputs of other sizes are accepted too.
    batch = x.size(0)
    BCE = F.binary_cross_entropy(recon_x.view(batch, -1),
                                 x.view(batch, -1), size_average=False)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD
127+
128+
129+
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Every ``args.log_interval`` batches the per-sample loss of the current
    batch is printed; after the pass the average per-sample loss is printed.
    """
    model.train()
    running_loss = 0
    for step, (batch, _) in enumerate(train_loader):
        batch = Variable(batch)
        if args.cuda:
            batch = batch.cuda()

        optimizer.zero_grad()
        recon, mu, logvar = model(batch)
        loss = loss_function(recon, batch, mu, logvar)
        loss.backward()
        # Scalar value of the summed batch loss.
        batch_loss = loss.data[0]
        running_loss += batch_loss
        optimizer.step()

        if step % args.log_interval != 0:
            continue
        seen = step * len(batch)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset),
            percent,
            batch_loss / len(batch)))

    average = running_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, average))
150+
151+
152+
def test(epoch):
    """Evaluate on ``test_loader`` and print the average per-sample loss.

    During the final epoch (``epoch == args.epochs``) the first batch is also
    used to save up to eight inputs next to their reconstructions under
    ``snapshots/conv_vae/``.
    """
    import os  # local import: only needed to create the snapshot directory

    model.eval()
    test_loss = 0
    for i, (data, _) in enumerate(test_loader):
        if args.cuda:
            data = data.cuda()
        data = Variable(data, volatile=True)
        recon_batch, mu, logvar = model(data)
        test_loss += loss_function(recon_batch, data, mu, logvar).data[0]
        if epoch == args.epochs and i == 0:
            n = min(data.size(0), 8)
            # Inputs first, reconstructions after; nrow=n is passed so the
            # two groups line up in the saved grid.
            comparison = torch.cat([data[:n],
                                    recon_batch[:n]])
            # Make sure the target directory exists before saving.
            snapshot_dir = 'snapshots/conv_vae'
            if not os.path.isdir(snapshot_dir):
                os.makedirs(snapshot_dir)
            save_image(comparison.data.cpu(),
                       'snapshots/conv_vae/reconstruction_' + str(epoch) +
                       '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
171+
172+
173+
import os

# Create the snapshot directory before training starts so that a missing
# directory cannot make the image saves fail after all epochs have run.
if not os.path.isdir('snapshots/conv_vae'):
    os.makedirs('snapshots/conv_vae')

for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
    if epoch == args.epochs:
        # After the last epoch, decode 64 latent vectors drawn from a
        # standard normal and save the resulting images.
        sample = Variable(torch.randn(64, args.hidden_size))
        if args.cuda:
            sample = sample.cuda()
        sample = model.decode(sample).cpu()
        save_image(sample.data.view(64, 3, 32, 32),
                   'snapshots/conv_vae/sample_' + str(epoch) + '.png')

main.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

0 commit comments

Comments
 (0)