Skip to content

Commit b1cf293

Browse files
committed
update 52
1 parent 4901f05 commit b1cf293

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

深度学习与TensorFlow入门实战-源码和PPT/lesson52-自定义数据集和迁移学习/train_scratch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
2-
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
2+
3+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
4+
35

46
import tensorflow as tf
57
import numpy as np
@@ -8,7 +10,7 @@
810
from tensorflow.keras.callbacks import EarlyStopping
911

1012
tf.random.set_seed(22)
11-
np.random.seed(22)
13+
np.random.seed(22)
1214
assert tf.__version__.startswith('2.')
1315
# 设置GPU显存按需分配
1416
gpus = tf.config.experimental.list_physical_devices('GPU')

深度学习与TensorFlow入门实战-源码和PPT/lesson52-自定义数据集和迁移学习/train_transfer.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import os
2-
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
3-
42
import tensorflow as tf
53
import numpy as np
64
from tensorflow import keras
@@ -16,7 +14,8 @@
1614
from tensorflow.keras.callbacks import EarlyStopping
1715

1816
tf.random.set_seed(22)
19-
np.random.seed(22)
17+
np.random.seed(22)
18+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
2019
assert tf.__version__.startswith('2.')
2120

2221

@@ -44,21 +43,20 @@ def preprocess(x,y):
4443

4544

4645
batchsz = 128
47-
48-
# creat train db
46+
# 创建训练集Datset对象
4947
images, labels, table = load_pokemon('pokemon',mode='train')
5048
db_train = tf.data.Dataset.from_tensor_slices((images, labels))
5149
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
52-
# crate validation db
50+
# 创建验证集Datset对象
5351
images2, labels2, table = load_pokemon('pokemon',mode='val')
5452
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
5553
db_val = db_val.map(preprocess).batch(batchsz)
56-
# create test db
54+
# 创建测试集Datset对象
5755
images3, labels3, table = load_pokemon('pokemon',mode='test')
5856
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
5957
db_test = db_test.map(preprocess).batch(batchsz)
6058

61-
59+
#
6260
net = keras.applications.VGG19(weights='imagenet', include_top=False,
6361
pooling='max')
6462
net.trainable = False

0 commit comments

Comments
 (0)