|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +# @Author: Arunabh Sharma |
| 3 | +# @Date: 2024-02-21 23:38:56 |
| 4 | +# @Last Modified by: Arunabh Sharma |
| 5 | +# @Last Modified time: 2024-02-25 22:14:30 |
| 6 | + |
| 7 | +import torch |
| 8 | +import torchvision |
| 9 | +import numpy as np |
| 10 | +import torchvision.transforms as transforms |
| 11 | +import cv2 |
| 12 | + |
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| 14 | + |
| 15 | + |
def torch_dataset_loader(img_display=False):
    """Build the MNIST train/test data loaders and preview one training image.

    The dataset is stored under ``../data/`` and downloaded when missing.
    One training batch is drawn so its tensor sizes can be printed, and the
    first image of that batch is upscaled to 800x800 for an optional preview.

    Args:
        img_display: When true, show the preview image in an OpenCV window
            and block until a key is pressed.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader yields
        shuffled batches of 100 samples; the test loader yields one sample
        at a time, unshuffled.
    """
    to_tensor = torchvision.transforms.ToTensor()
    train_dataset = torchvision.datasets.MNIST(
        root="../data/",
        train=True,
        transform=to_tensor,
        download=True,
    )
    test_dataset = torchvision.datasets.MNIST(
        root="../data/",
        train=False,
        transform=to_tensor,
        download=True,
    )

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=100, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=1, shuffle=False
    )

    images, labels = next(iter(train_loader))
    print(images.size())
    print(labels.size())

    # Channel-first tensor -> channel-last array, the layout OpenCV expects.
    img = images[0].numpy().transpose(1, 2, 0)
    img = cv2.resize(img, (800, 800))
    # MNIST images have a single channel, so expand gray -> BGR. The
    # RGB2BGR conversion only accepts 3- or 4-channel input.
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    print(img.shape)

    if img_display:
        cv2.imshow("images", img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    return train_loader, test_loader
| 53 | + |
| 54 | + |
class ConvNet(torch.nn.Module):
    """Two-stage convolutional classifier that outputs 10 class scores.

    Each stage is: 5x5 convolution (no padding) -> batch norm -> ReLU ->
    2x2 max pool. The final linear layer expects 512 flattened features,
    which is what a 1x28x28 input produces (28 -> 24 -> 12 -> 8 -> 4 per
    side, and 32 * 4 * 4 = 512).
    """

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1)),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2, 2)),
        )

        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1)),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2, 2)),
        )

        self.fc1 = torch.nn.Linear(512, 10)

    def forward(self, x):
        """Return class scores of shape ``(batch, 10)`` for a batch ``x``.

        Args:
            x: Input batch with one channel, laid out as
                ``(batch, 1, height, width)``.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        # Flatten everything but the batch dimension for the linear layer.
        # (The per-call shape prints were debug output and are removed.)
        x = x.reshape(x.size(0), -1)
        return self.fc1(x)
| 82 | + |
| 83 | + |
def train_model(model, epoch, train_loader, model_name, model_save):
    """Train ``model`` with Adam (lr=1e-4) and cross-entropy loss.

    Every batch of images is reshaped to ``(-1, 28, 28)`` before the forward
    pass, i.e. each 28x28 image is fed as 28 rows of 28 values. The loss of
    every batch is printed.

    Args:
        model: Module to train; it is moved to the module-level ``device``.
        epoch: Number of full passes over ``train_loader``.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        model_name: Path the trained model is written to.
        model_save: When true, save the whole model object to ``model_name``
            once training has finished.

    Returns:
        The trained model (on ``device``).
    """
    model = model.to(device)
    # Make sure train-time layer behaviour (batch norm, dropout) is active
    # even if the caller last used the model in eval mode.
    model.train()

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    for i in range(epoch):
        for images, labels in train_loader:
            images = images.reshape(-1, 28, 28).to(device)
            labels = labels.to(device)
            loss = criterion(model(images), labels)

            # NOTE: the number printed after "Iteration" is the epoch index.
            print(f"Iteration: {i}, Loss: {loss.item()}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if model_save:
        torch.save(model, model_name)

    return model
| 108 | + |
| 109 | + |
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of
            dict-like groups holding an ``"lr"`` entry.
        lr: New learning rate written into each group.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
| 113 | + |
| 114 | + |
def train_model_decay_lr(model, epoch, lr, train_loader, model_name, model_save):
    """Train ``model`` with Adam and cross-entropy loss, decaying the lr.

    Training starts at learning rate ``lr``; after every epoch the rate is
    divided by 3 and written back into the optimizer. Images are passed to
    the model exactly as the loader yields them. The loss of every batch is
    printed.

    Args:
        model: Module to train; it is moved to the module-level ``device``.
        epoch: Number of full passes over ``train_loader``.
        lr: Initial learning rate.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        model_name: Path the trained model is written to.
        model_save: When true, save the whole model object to ``model_name``
            once training has finished.

    Returns:
        The trained model (on ``device``).
    """
    model = model.to(device)
    # Make sure train-time layer behaviour (batch norm, dropout) is active
    # even if the caller last used the model in eval mode.
    model.train()

    curr_lr = lr
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=curr_lr)

    for i in range(epoch):
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            loss = criterion(model(images), labels)

            # NOTE: the number printed after "Iteration" is the epoch index.
            print(f"Iteration: {i}, Loss: {loss.item()}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Step decay: shrink the learning rate once per epoch.
        curr_lr /= 3.0
        update_lr(optimizer, curr_lr)

    if model_save:
        torch.save(model, model_name)

    return model
| 141 | + |
| 142 | + |
def explore_model(test_loader, model_name):
    """Step through test samples, showing ground-truth and predicted labels.

    The saved model is loaded onto the CPU and run in eval mode. Each sample
    is reshaped to ``(-1, 28, 28)``, classified, upscaled to 800x800 and
    shown in an OpenCV window with both labels drawn on top. Any key
    advances to the next sample; ``q`` quits.

    Args:
        test_loader: Iterable yielding ``(image, label)`` pairs; the first
            element of each batch is the one displayed.
        model_name: Path of a model saved as a whole object with
            ``torch.save(model, path)``.

    Raises:
        SystemExit: With code 0 when ``q`` is pressed (terminates the
            process, as before).
    """
    # SECURITY: loading a whole pickled module executes arbitrary code from
    # the file - only open checkpoints you trust. ``weights_only=False`` is
    # needed because the file holds a full module rather than a state dict;
    # ``map_location`` lets GPU-trained checkpoints load on CPU-only hosts.
    model = torch.load(model_name, map_location="cpu", weights_only=False)
    model.eval()

    window_name = "Image"
    font = cv2.FONT_HERSHEY_SIMPLEX
    # Text anchors (bottom-left corner) for the two label lines.
    org0 = (30, 30)
    org1 = (30, 60)
    font_scale = 1
    # Green in BGR.
    color = (0, 255, 0)
    # Line thickness in pixels.
    thickness = 2

    try:
        with torch.no_grad():
            for image, label in test_loader:
                image = image.reshape(-1, 28, 28)
                scores = model(image)
                predicted = int(np.argmax(scores.numpy()))

                # Upscale the grayscale digit and replicate it into three
                # channels so coloured text can be drawn on it.
                img = cv2.resize(image[0].numpy(), (800, 800))
                img = np.dstack((img, img, img))

                captions = (
                    ("GT Label:" + str(int(label[0])), org0),
                    ("Pred Label:" + str(predicted), org1),
                )
                for text, origin in captions:
                    img = cv2.putText(
                        img,
                        text,
                        origin,
                        font,
                        font_scale,
                        color,
                        thickness,
                        cv2.LINE_AA,
                    )

                cv2.imshow(window_name, img)
                if (cv2.waitKey(0) & 0xFF) == ord("q"):
                    # Same effect as exit(0) without relying on the
                    # site-provided ``exit`` helper.
                    raise SystemExit(0)
    finally:
        # Close the window on quit, on errors, and when the data runs out.
        cv2.destroyAllWindows()
| 198 | + |
| 199 | + |
class ResnetBlock(torch.nn.Module):
    """Residual block: two 3x3 convolutions plus a shortcut connection.

    Main path: conv -> batch norm -> ReLU -> conv -> batch norm. The
    shortcut is added to the result and a final ReLU is applied. Each
    convolution has its own batch-norm layer so their affine parameters and
    running statistics are not shared.

    NOTE: models pickled before ``bn2`` existed do not carry that attribute
    and have to be retrained or re-saved.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the block.
        downsample: When true, the first convolution and the shortcut use
            stride 2, halving the spatial size.
    """

    def __init__(self, in_channels, out_channels, downsample):
        super().__init__()

        stride = 2 if downsample else 1
        self.conv1 = self.conv_layer(in_channels, out_channels, 3, stride)
        if downsample or in_channels != out_channels:
            # Project the shortcut with a 1x1 convolution whenever the main
            # path changes the spatial size or the channel count, so both
            # branches have the same shape when they are summed.
            self.ds = self.conv_layer(in_channels, out_channels, 1, stride, 0)
        else:
            self.ds = torch.nn.Identity()
        self.bn1 = torch.nn.BatchNorm2d(out_channels)
        self.nl1 = torch.nn.ReLU(inplace=True)
        self.conv2 = self.conv_layer(out_channels, out_channels, 3, 1)
        self.bn2 = torch.nn.BatchNorm2d(out_channels)

    def conv_layer(self, in_channels, out_channels, kernel_size, stride, padding=1):
        """Return a ``Conv2d`` with the given geometry (padding defaults to 1)."""
        return torch.nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        """Apply the block to ``x`` laid out as ``(batch, channels, H, W)``."""
        residual = self.ds(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.nl1(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # Out-of-place add keeps the shortcut tensor untouched.
        out = out + residual
        return self.nl1(out)
| 230 | + |
| 231 | + |
class Resnet(torch.nn.Module):
    """Small residual classifier that outputs 10 class scores.

    Pipeline: ``ResnetBlock(1, 16)`` -> downsampling ``ResnetBlock(16, 32)``
    -> 2x2 average pool -> linear layer. The linear layer expects
    7 * 7 * 32 flattened features (presumably sized for 1x28x28 MNIST
    images, as produced by the loaders in this file).
    """

    def __init__(self):
        super().__init__()

        self.layer1 = ResnetBlock(1, 16, False)
        self.layer2 = ResnetBlock(16, 32, True)
        self.avg_pool = torch.nn.AvgPool2d(2)
        self.fc = torch.nn.Linear(7 * 7 * 32, 10)

    def forward(self, x):
        """Return the class scores for the input batch ``x``."""
        features = self.avg_pool(self.layer2(self.layer1(x)))
        # Collapse everything except the batch dimension.
        flat = torch.flatten(features, start_dim=1)
        return self.fc(flat)
| 248 | + |
| 249 | + |
class RNN(torch.nn.Module):
    """LSTM classifier that scores a sequence from its last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output scores.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return scores of shape ``(batch, num_classes)``.

        Args:
            x: Batch-first input of shape ``(batch, seq_length, input_size)``.
        """
        # Zero initial hidden and cell states, created on the same device
        # and with the same dtype as the input so the model also runs on
        # CUDA or in non-default precision.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        # out: tensor of shape (batch_size, seq_length, hidden_size)
        out, _ = self.lstm(x, (h0, c0))

        # Decode the hidden state of the last time step.
        return self.fc(out[:, -1, :])
| 271 | + |
| 272 | + |
if __name__ == "__main__":
    # Checkpoint path shared by the (currently disabled) training call and
    # the interactive viewer.
    model_name = "rnn.ckpt"
    train_loader, test_loader = torch_dataset_loader(img_display=False)
    # LSTM classifier: 28 features per step, 128 hidden units, 2 layers,
    # 10 classes.
    model = RNN(input_size=28, hidden_size=128, num_layers=2, num_classes=10)
    # To train and save the checkpoint before viewing, enable:
    # _ = train_model(model, 4, train_loader, model_name, True)
    explore_model(test_loader, model_name)
0 commit comments