Skip to content

Commit dbaa30a

Browse files
Update resnet.py
Enable training else the parameter will not get updated
1 parent add77e7 commit dbaa30a

File tree

1 file changed

+7
-7
lines changed
  • 深度学习与TensorFlow入门实战-源码和PPT/lesson43-ResNet

1 file changed

+7
-7
lines changed

深度学习与TensorFlow入门实战-源码和PPT/lesson43-ResNet/resnet.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ def call(self, inputs, training=None):
2828

2929
# [b, h, w, c]
3030
out = self.conv1(inputs)
31-
out = self.bn1(out)
31+
out = self.bn1(out,training=training)
3232
out = self.relu(out)
3333

3434
out = self.conv2(out)
35-
out = self.bn2(out)
35+
out = self.bn2(out,training=training)
3636

3737
identity = self.downsample(inputs)
3838

@@ -71,10 +71,10 @@ def call(self, inputs, training=None):
7171

7272
x = self.stem(inputs)
7373

74-
x = self.layer1(x)
75-
x = self.layer2(x)
76-
x = self.layer3(x)
77-
x = self.layer4(x)
74+
x = self.layer1(x,training=training)
75+
x = self.layer2(x,training=training)
76+
x = self.layer3(x,training=training)
77+
x = self.layer4(x,training=training)
7878

7979
# [b, c]
8080
x = self.avgpool(x)
@@ -102,4 +102,4 @@ def resnet18():
102102

103103

104104
def resnet34():
105-
return ResNet([3, 4, 6, 3])
105+
return ResNet([3, 4, 6, 3])

0 commit comments

Comments
 (0)