1+ import torch
2+ import torch .nn as nn
3+ import torchvision
4+ import torchvision .transforms as transforms
5+
6+
# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
input_size = 784       # equals the 28 * 28 reshape applied to every image batch
hidden_size = 500      # width of the single hidden layer
num_classes = 10       # number of output scores produced by the network
num_epochs = 5         # full passes over the training loader
batch_size = 100       # mini-batch size used by both data loaders
learning_rate = 0.001  # learning rate handed to the Adam optimizer
17+
# MNIST dataset. Only the training split passes download=True; the test split
# reads from the same root directory.
train_dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

test_dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=False,
    transform=transforms.ToTensor(),
)

# Data loaders: training batches are shuffled, test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
36+
37+ # Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
    """Fully connected network: linear -> ReLU -> linear.

    Args:
        input_size: Number of features in the last dimension of the input.
        hidden_size: Number of units in the hidden layer.
        num_classes: Number of scores produced per sample.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Attribute names and registration order are kept as-is so that
        # state_dict keys (fc1.*, fc2.*) stay compatible with saved checkpoints.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the raw output scores for ``x``.

        No softmax is applied here; the output of the second linear layer is
        returned directly.
        """
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
50+
model = NeuralNet(input_size, hidden_size, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each batch to rows of 28 * 28 values and move the tensors
        # to the configured device.
        images = images.reshape(-1, 28 * 28).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize. Gradients are cleared first so they do not
        # accumulate across steps.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the current batch loss every 100 steps.
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

# Test the model
# Switch to evaluation mode; this is a no-op for the layers used here but keeps
# the evaluation correct if dropout or batch-norm layers are added later.
model.eval()
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28 * 28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Index of the largest score along dim 1 is the predicted class.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint (parameters only, via state_dict)
torch.save(model.state_dict(), 'model.ckpt')
0 commit comments