Skip to content

Commit 53d16c4

Browse files
committed
add code for recurrent nets and resnet
1 parent 31b0fe5 commit 53d16c4

File tree

1 file changed

+278
-0
lines changed

1 file changed

+278
-0
lines changed
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
# -*- coding: utf-8 -*-
2+
# @Author: Arunabh Sharma
3+
# @Date: 2024-02-21 23:38:56
4+
# @Last Modified by: Arunabh Sharma
5+
# @Last Modified time: 2024-02-25 22:14:30
6+
7+
import torch
8+
import torchvision
9+
import numpy as np
10+
import torchvision.transforms as transforms
11+
import cv2
12+
13+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14+
15+
16+
def torch_dataset_loader(img_display=False):
17+
train_dataset = torchvision.datasets.MNIST(
18+
root="../data/",
19+
train=True,
20+
transform=torchvision.transforms.ToTensor(),
21+
download=True,
22+
)
23+
test_dataset = torchvision.datasets.MNIST(
24+
root="../data/",
25+
train=False,
26+
transform=torchvision.transforms.ToTensor(),
27+
download=True,
28+
)
29+
30+
train_loader = torch.utils.data.DataLoader(
31+
dataset=train_dataset, batch_size=100, shuffle=True
32+
)
33+
test_loader = torch.utils.data.DataLoader(
34+
dataset=test_dataset, batch_size=1, shuffle=False
35+
)
36+
37+
data_iter = iter(train_loader)
38+
images, labels = next(data_iter)
39+
print(images.size())
40+
print(labels.size())
41+
42+
img = images[0].numpy().transpose(1, 2, 0)
43+
img = cv2.resize(img, (800, 800))
44+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
45+
print(img.shape)
46+
47+
if img_display:
48+
cv2.imshow("images", img)
49+
cv2.waitKey(0)
50+
cv2.destroyAllWindows()
51+
52+
return train_loader, test_loader
53+
54+
55+
class ConvNet(torch.nn.Module):
56+
def __init__(self):
57+
super(ConvNet, self).__init__()
58+
self.layer1 = torch.nn.Sequential(
59+
torch.nn.Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1)),
60+
torch.nn.BatchNorm2d(16),
61+
torch.nn.ReLU(),
62+
torch.nn.MaxPool2d(kernel_size=(2, 2)),
63+
)
64+
65+
self.layer2 = torch.nn.Sequential(
66+
torch.nn.Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1)),
67+
torch.nn.BatchNorm2d(32),
68+
torch.nn.ReLU(),
69+
torch.nn.MaxPool2d(kernel_size=(2, 2)),
70+
)
71+
72+
self.fc1 = torch.nn.Linear(512, 10)
73+
74+
def forward(self, x):
75+
x = self.layer1(x)
76+
x = self.layer2(x)
77+
print(x.shape)
78+
x = x.reshape(x.size(0), -1)
79+
print(x.shape)
80+
x = self.fc1(x)
81+
return x
82+
83+
84+
def train_model(model, epoch, train_loader, model_name, model_save):
85+
model = model.to(device)
86+
87+
criterion = torch.nn.CrossEntropyLoss()
88+
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
89+
90+
for i in range(epoch):
91+
data_iter = iter(train_loader)
92+
while (ret := next(data_iter, None)) is not None:
93+
images, labels = ret[0], ret[1]
94+
# images = images.to(device)
95+
images = images.reshape(-1, 28, 28).to(device)
96+
labels = labels.to(device)
97+
loss = criterion(model(images), labels)
98+
99+
print(f"Iteration: {i}, Loss: {loss.item()}")
100+
101+
optimizer.zero_grad()
102+
loss.backward()
103+
optimizer.step()
104+
if model_save:
105+
torch.save(model, model_name)
106+
107+
return model
108+
109+
110+
def update_lr(optimizer, lr):
111+
for param_group in optimizer.param_groups:
112+
param_group["lr"] = lr
113+
114+
115+
def train_model_decay_lr(model, epoch, lr, train_loader, model_name, model_save):
116+
model = model.to(device)
117+
118+
curr_lr = lr
119+
criterion = torch.nn.CrossEntropyLoss()
120+
optimizer = torch.optim.Adam(model.parameters(), lr=curr_lr)
121+
122+
for i in range(epoch):
123+
data_iter = iter(train_loader)
124+
while (ret := next(data_iter, None)) is not None:
125+
images, labels = ret[0], ret[1]
126+
images = images.to(device)
127+
labels = labels.to(device)
128+
loss = criterion(model(images), labels)
129+
130+
print(f"Iteration: {i}, Loss: {loss.item()}")
131+
132+
optimizer.zero_grad()
133+
loss.backward()
134+
optimizer.step()
135+
curr_lr /= 3.0
136+
update_lr(optimizer, curr_lr)
137+
if model_save:
138+
torch.save(model, model_name)
139+
140+
return model
141+
142+
143+
def explore_model(test_loader, model_name):
144+
model = torch.load(model_name).to("cpu")
145+
model.eval()
146+
147+
window_name = "Image"
148+
149+
# font
150+
font = cv2.FONT_HERSHEY_SIMPLEX
151+
152+
# org
153+
org0 = (30, 30)
154+
org1 = (30, 60)
155+
156+
# fontScale
157+
fontScale = 1
158+
159+
# Blue color in BGR
160+
color = (0, 255, 0)
161+
162+
# Line thickness of 2 px
163+
thickness = 2
164+
165+
with torch.no_grad():
166+
for image, label in test_loader:
167+
image = image.reshape(-1, 28, 28)
168+
pred_label = model(image)
169+
img = image[0].numpy()
170+
img = cv2.resize(img, (800, 800))
171+
img = np.dstack((img, img, img))
172+
img = cv2.putText(
173+
img,
174+
"GT Label:" + str(int(label[0])),
175+
org0,
176+
font,
177+
fontScale,
178+
color,
179+
thickness,
180+
cv2.LINE_AA,
181+
)
182+
183+
pred_label = np.argmax(pred_label.numpy())
184+
img = cv2.putText(
185+
img,
186+
"Pred Label:" + str(int(pred_label)),
187+
org1,
188+
font,
189+
fontScale,
190+
color,
191+
thickness,
192+
cv2.LINE_AA,
193+
)
194+
cv2.imshow(window_name, img)
195+
if (cv2.waitKey(0) & 0xFF) == ord("q"):
196+
cv2.destroyAllWindows()
197+
exit(0)
198+
199+
200+
class ResnetBlock(torch.nn.Module):
201+
def __init__(self, in_channels, out_channels, downsample):
202+
super(ResnetBlock, self).__init__()
203+
204+
if downsample:
205+
self.conv1 = self.conv_layer(in_channels, out_channels, 3, 2)
206+
self.ds = self.conv_layer(in_channels, out_channels, 1, 2, 0)
207+
else:
208+
self.conv1 = self.conv_layer(in_channels, out_channels, 3, 1)
209+
self.ds = torch.nn.Identity()
210+
self.bn1 = torch.nn.BatchNorm2d(out_channels)
211+
self.nl1 = torch.nn.ReLU(inplace=True)
212+
self.conv2 = self.conv_layer(out_channels, out_channels, 3, 1)
213+
214+
def conv_layer(self, in_channels, out_channels, kernel_size, stride, padding=1):
215+
return torch.nn.Conv2d(
216+
in_channels, out_channels, kernel_size, stride=stride, padding=padding
217+
)
218+
219+
def forward(self, x):
220+
out = self.conv1(x)
221+
residual = self.ds(x)
222+
out = self.bn1(out)
223+
out = self.nl1(out)
224+
out = self.conv2(out)
225+
out = self.bn1(out)
226+
out += residual
227+
out = self.nl1(out)
228+
229+
return out
230+
231+
232+
class Resnet(torch.nn.Module):
233+
def __init__(self):
234+
super(Resnet, self).__init__()
235+
236+
self.layer1 = ResnetBlock(1, 16, False)
237+
self.layer2 = ResnetBlock(16, 32, True)
238+
self.avg_pool = torch.nn.AvgPool2d(2)
239+
self.fc = torch.nn.Linear(7 * 7 * 32, 10)
240+
241+
def forward(self, x):
242+
out = self.layer1(x)
243+
out = self.layer2(out)
244+
out = self.avg_pool(out)
245+
out = out.reshape(out.size(0), -1)
246+
out = self.fc(out)
247+
return out
248+
249+
250+
class RNN(torch.nn.Module):
251+
def __init__(self, input_size, hidden_size, num_layers, num_classes):
252+
super(RNN, self).__init__()
253+
self.hidden_size = hidden_size
254+
self.num_layers = num_layers
255+
self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
256+
self.fc = torch.nn.Linear(hidden_size, num_classes)
257+
258+
def forward(self, x):
259+
# Set initial hidden and cell states
260+
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
261+
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
262+
263+
# Forward propagate LSTM
264+
out, _ = self.lstm(
265+
x, (h0, c0)
266+
) # out: tensor of shape (batch_size, seq_length, hidden_size)
267+
268+
# Decode the hidden state of the last time step
269+
out = self.fc(out[:, -1, :])
270+
return out
271+
272+
273+
if __name__ == "__main__":
274+
train_loader, test_loader = torch_dataset_loader(False)
275+
model_name = "rnn.ckpt"
276+
model = RNN(28, 128, 2, 10)
277+
# _ = train_model(model, 4, train_loader, model_name, True)
278+
explore_model(test_loader, model_name)

0 commit comments

Comments
 (0)