
Commit d82737a

committed
upload cifar10
1 parent 8d058eb commit d82737a

File tree

1 file changed (+174, -0)


chapter4_CNN/cifar10/main.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import DataLoader

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image Preprocessing
train_transform = transforms.Compose([
    transforms.Resize(40),  # upscale so RandomCrop(32) adds translation jitter
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# CIFAR-10 Dataset
train_dataset = dsets.CIFAR10(
    root='./data', train=True, transform=train_transform, download=True)

test_dataset = dsets.CIFAR10(
    root='./data', train=False, transform=test_transform)

# Data Loader (Input Pipeline)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

# 3x3 Convolution
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)

# Residual Block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        # Add the (possibly downsampled) identity shortcut before the final ReLU.
        out += residual
        out = self.relu(out)
        return out

# ResNet Module
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Three stages of residual blocks; the spatial size halves and the
        # width doubles at each stage transition, and each stage takes its
        # block count from its own entry in `layers`.
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        # Downsample the identity branch whenever the spatial size or channel
        # count changes, so it can still be added to the block output.
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = []
        layers.append(
            block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

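# With layers=[3, 3, 3] as below, each stage stacks 3 residual blocks of 2
# convolutions each; counting the stem convolution and the final fully
# connected layer gives the 20-layer (6n+2, n=3) configuration used for
# CIFAR-10 in the original ResNet paper.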
resnet = ResNet(ResidualBlock, [3, 3, 3]).to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer = optim.Adam(resnet.parameters(), lr=lr)

# Training
total_epoch = 50
for epoch in range(total_epoch):
    resnet.train()  # BatchNorm should use batch statistics while training
    running_loss = 0
    running_acc = 0
    running_num = 0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward + Backward + Optimize
        optimizer.zero_grad()
        outputs = resnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # =====================log=====================
        running_num += labels.size(0)
        running_loss += loss.item() * labels.size(0)
        _, correct_label = torch.max(outputs, 1)
        running_acc += (correct_label == labels).sum().item()
        if (i + 1) % 100 == 0:
            print_loss = running_loss / running_num
            print_acc = running_acc / running_num
            print("Epoch [{}/{}], Iter [{}/{}] Loss: {:.6f} Acc: {:.6f}".
                  format(epoch + 1, total_epoch, i + 1,
                         len(train_loader), print_loss, print_acc))

    # Decaying Learning Rate
    if (epoch + 1) % 20 == 0:
        lr /= 3
        # Update the learning rate in place rather than rebuilding the
        # optimizer, which would throw away Adam's moment estimates.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

# Test
resnet.eval()  # use the running BatchNorm statistics at evaluation time
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = resnet(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the model on the test images: {:.2f} %'.format(
    100 * correct / total))

# Save the Model
torch.save(resnet.state_dict(), 'resnet.pth')
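
Since only the state dict is saved, reloading for inference means re-instantiating the model class first. A minimal sketch, reusing the ResNet/ResidualBlock definitions, device, and test_dataset from the script above; the sample lookup is illustrative only:

model = ResNet(ResidualBlock, [3, 3, 3]).to(device)
model.load_state_dict(torch.load('resnet.pth', map_location=device))
model.eval()

with torch.no_grad():
    image, label = test_dataset[0]  # any (3, 32, 32) tensor works here
    logits = model(image.unsqueeze(0).to(device))
    print('predicted:', logits.argmax(dim=1).item(), 'ground truth:', label)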
