Skip to content

Commit a981fef

Browse files
committed
update
1 parent 4ec6c08 commit a981fef

File tree

4 files changed

+96
-6
lines changed

4 files changed

+96
-6
lines changed

chapter4_CNN/alexnet.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from torch import nn
2+
3+
4+
class AlexNet(nn.Module):
5+
def __init__(self, num_classes):
6+
super(AlexNet, self).__init__()
7+
self.features = nn.Sequential(
8+
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
9+
nn.ReLU(inplace=True),
10+
nn.MaxPool2d(kernel_size=3, stride=2),
11+
nn.Conv2d(64, 192, kernel_size=5, padding=2),
12+
nn.ReLU(inplace=True),
13+
nn.MaxPool2d(kernel_size=3, stride=2),
14+
nn.Conv2d(192, 384, kernel_size=3, padding=1),
15+
nn.ReLU(inplace=True),
16+
nn.Conv2d(384, 256, kernel_size=3, padding=1),
17+
nn.ReLU(inplace=True),
18+
nn.Conv2d(256, 256, kernel_size=3, padding=1),
19+
nn.ReLU(inplace=True),
20+
nn.MaxPool2d(kernel_size=3, stride=2), )
21+
self.classifier = nn.Sequential(
22+
nn.Dropout(),
23+
nn.Linear(256 * 6 * 6, 4096),
24+
nn.ReLU(inplace=True),
25+
nn.Dropout(),
26+
nn.Linear(4096, 4096),
27+
nn.ReLU(inplace=True),
28+
nn.Linear(4096, num_classes), )
29+
30+
def forward(self, x):
31+
x = self.features(x)
32+
x = x.view(x.size(0), 256 * 6 * 6)
33+
x = self.classifier(x)
34+
return x

chapter4_CNN/lenet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@ def __init__(self):
2424
def forward(self, x):
2525
x = self.layer1(x)
2626
x = self.layer2(x)
27+
x = x.view(x.size(0), -1)
2728
x = self.layer3(x)
2829
return x

chapter4_CNN/net_structure.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,14 @@ def forward(self, x):
4949

5050
conv_model = nn.Sequential()
5151
for layer in model.named_modules():
52-
if 'conv' in layer[0]:
52+
if isinstance(layer[1], nn.Conv2d):
5353
conv_model.add_module(layer[0], layer[1])
5454

55-
for param in model.named_parameters():
56-
if 'conv' in param[0] and 'weight' in param[0]:
57-
init.normal(param[1].data)
58-
init.xavier_normal(param[1].data)
59-
init.kaiming_normal(param[1].data)
55+
for m in model.modules():
56+
if isinstance(m, nn.Conv2d):
57+
init.normal(m.weight.data)
58+
init.xavier_normal(m.weight.data)
59+
init.kaiming_normal(m.weight.data)
60+
m.bias.data.fill_(0)
61+
elif isinstance(m, nn.Linear):
62+
m.weight.data.normal_()

chapter4_CNN/vgg.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from torch import nn
2+
3+
4+
class VGG(nn.Module):
5+
def __init__(self, num_classes):
6+
super(VGG, self).__init__()
7+
self.features = nn.Sequential(
8+
nn.Conv2d(3, 64, kernel_size=3, padding=1),
9+
nn.ReLU(True),
10+
nn.Conv2d(64, 64, kernel_size=3, padding=1),
11+
nn.ReLU(True),
12+
nn.MaxPool2d(kernel_size=2, stride=2),
13+
nn.Conv2d(64, 128, kernel_size=3, padding=1),
14+
nn.ReLU(True),
15+
nn.Conv2d(128, 128, kernel_size=3, padding=1),
16+
nn.ReLU(True),
17+
nn.MaxPool2d(kernel_size=2, stride=2),
18+
nn.Conv2d(128, 256, kernel_size=3, padding=1),
19+
nn.ReLU(True),
20+
nn.Conv2d(256, 256, kernel_size=3, padding=1),
21+
nn.ReLU(True),
22+
nn.Conv2d(256, 256, kernel_size=3, padding=1),
23+
nn.ReLU(True),
24+
nn.MaxPool2d(kernel_size=2, stride=2),
25+
nn.Conv2d(256, 512, kernel_size=3, padding=1),
26+
nn.ReLU(True),
27+
nn.Conv2d(512, 512, kernel_size=3, padding=1),
28+
nn.ReLU(True),
29+
nn.Conv2d(512, 512, kernel_size=3, padding=1),
30+
nn.ReLU(True),
31+
nn.MaxPool2d(kernel_size=2, stride=2),
32+
nn.Conv2d(512, 512, kernel_size=3, padding=1),
33+
nn.ReLU(True),
34+
nn.Conv2d(512, 512, kernel_size=3, padding=1),
35+
nn.ReLU(True),
36+
nn.Conv2d(512, 512, kernel_size=3, padding=1),
37+
nn.ReLU(True),
38+
nn.MaxPool2d(kernel_size=2, stride=2), )
39+
self.classifier = nn.Sequential(
40+
nn.Linear(512 * 7 * 7, 4096),
41+
nn.ReLU(True),
42+
nn.Dropout(),
43+
nn.Linear(4096, 4096),
44+
nn.ReLU(True),
45+
nn.Dropout(),
46+
nn.Linear(4096, num_classes), )
47+
self._initialize_weights()
48+
49+
def forward(self, x):
50+
x = self.features(x)
51+
x = x.view(x.size(0), -1)
52+
x = self.classifier(x)

0 commit comments

Comments
 (0)