Skip to content

Commit 31b0fe5

Browse files
committed
move pytorch basics up a level
1 parent c2da39d commit 31b0fe5

File tree

1 file changed

+54
-8
lines changed

1 file changed

+54
-8
lines changed

tutorials/01-basics/pytorch_basics/pytorch_basics.py renamed to tutorials/01-basics/pytorch_basics.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# @Author: Arunabh Sharma
33
# @Date: 2024-02-19 23:14:58
44
# @Last Modified by: Arunabh Sharma
5-
# @Last Modified time: 2024-02-21 00:31:38
5+
# @Last Modified time: 2024-02-21 23:16:35
66

77

88
# Basic quadratic autograd example
@@ -11,6 +11,8 @@
1111
import torchvision
1212
import cv2
1313

14+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15+
1416

1517
def quadratic_function():
1618
x = torch.tensor(8.0, requires_grad=True)
@@ -134,14 +136,18 @@ def transfer_learning(train_loader, model_save=True):
134136
resnet.fc = torch.nn.Linear(resnet.fc.in_features, 10)
135137
print(resnet)
136138

137-
n_iter = 5
139+
resnet = resnet.to(device)
140+
141+
epoch = 5
138142
criterion = torch.nn.CrossEntropyLoss()
139-
optimizer = torch.optim.SGD(resnet.parameters(), lr=1e-6)
143+
optimizer = torch.optim.Adam(resnet.parameters())
140144

141-
for i in range(n_iter):
145+
for i in range(epoch):
142146
data_iter = iter(train_loader)
143147
while (ret := next(data_iter, None)) is not None:
144148
images, labels = ret[0], ret[1]
149+
images = images.to(device)
150+
labels = labels.to(device)
145151
loss = criterion(resnet(images), labels)
146152

147153
print(f"Iteration: {i}, Loss: {loss.item()}")
@@ -155,8 +161,8 @@ def transfer_learning(train_loader, model_save=True):
155161
return resnet
156162

157163

158-
def explore_transfer_learning(test_loader):
159-
model = torch.load("model.ckpt")
164+
def explore_transfer_learning(test_loader, model_name):
165+
model = torch.load(model_name).to("cpu")
160166
model.eval()
161167

162168
window_name = "Image"
@@ -179,6 +185,8 @@ def explore_transfer_learning(test_loader):
179185

180186
with torch.no_grad():
181187
for image, label in test_loader:
188+
# image_r = image.reshape(-1, 28 * 28)
189+
# pred_label = model(image_r)
182190
pred_label = model(image)
183191
img = image[0].numpy().transpose(1, 2, 0)
184192
img = cv2.resize(img, (800, 800))
@@ -211,7 +219,45 @@ def explore_transfer_learning(test_loader):
211219
exit(0)
212220

213221

222+
class NeuralNetwork(torch.nn.Module):
223+
def __init__(self):
224+
super(NeuralNetwork, self).__init__()
225+
self.fc0 = torch.nn.Linear(28 * 28, 128)
226+
self.nl1 = torch.nn.ReLU()
227+
self.fc1 = torch.nn.Linear(128, 10)
228+
229+
def forward(self, x):
230+
return self.fc1(self.nl1(self.fc0(x)))
231+
232+
233+
def train_nn(train_loader, model_save=True):
234+
model = NeuralNetwork().to(device)
235+
236+
epoch = 1
237+
criterion = torch.nn.CrossEntropyLoss()
238+
optimizer = torch.optim.Adam(model.parameters())
239+
240+
for i in range(epoch):
241+
data_iter = iter(train_loader)
242+
while (ret := next(data_iter, None)) is not None:
243+
images, labels = ret[0], ret[1]
244+
images = images.reshape(-1, 28 * 28).to(device)
245+
labels = labels.to(device)
246+
loss = criterion(model(images), labels)
247+
248+
print(f"Iteration: {i}, Loss: {loss.item()}")
249+
250+
optimizer.zero_grad()
251+
loss.backward()
252+
optimizer.step()
253+
if model_save:
254+
torch.save(model, "nn.ckpt")
255+
256+
return model
257+
258+
214259
if __name__ == "__main__":
215260
train_loader, test_loader = torch_dataset_loader(False)
216-
# _ = transfer_learning(train_loader)
217-
explore_transfer_learning(test_loader)
261+
_ = transfer_learning(train_loader)
262+
explore_transfer_learning(test_loader, "model.ckpt")
263+
# explore_transfer_learning(test_loader, "nn.ckpt")

0 commit comments

Comments
 (0)