
Commit 8ee7008

Add CMLE multi-gpu sample
1 parent c4b38a2 commit 8ee7008

7 files changed: +389 −0 lines changed
Experimental/distribution/multi-gpu/cmle/config.yaml

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
trainingInput:
  scaleTier: CUSTOM
  masterType: complex_model_l_gpu
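
This config selects the CUSTOM scale tier with a complex_model_l_gpu master, i.e. a single host with eight NVIDIA K80 GPUs, which is what the notebook's --num_gpus 8 flag below assumes. For orientation, here is a minimal sketch of submitting the same trainingInput through the CMLE REST API with the Google API client; the job ID, bucket, package URI, and region are placeholders, not values from the sample:

from googleapiclient import discovery

# Hypothetical job spec; only scaleTier and masterType come from config.yaml.
job = {
    'jobId': 'sample_model_manual',
    'trainingInput': {
        'scaleTier': 'CUSTOM',
        'masterType': 'complex_model_l_gpu',  # single host, 8x K80
        'packageUris': ['gs://YOUR-BUCKET-ID/staging/trainer-0.1.tar.gz'],
        'pythonModule': 'trainer.task',
        'region': 'us-central1',
        'runtimeVersion': '1.11',
    },
}
ml = discovery.build('ml', 'v1')
ml.projects().jobs().create(parent='projects/YOUR-PROJECT-ID', body=job).execute()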

Experimental/distribution/multi-gpu/cmle/project/__init__.py

Whitespace-only changes.
Experimental/distribution/multi-gpu/cmle/project/setup.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
from setuptools import find_packages
from setuptools import setup

REQUIRED_PACKAGES = []

setup(
    name='trainer',
    version='0.1',
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    include_package_data=True,
    description='Generic example trainer package.',
)
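
This setup.py is what lets gcloud build and stage the trainer package before the job starts. As a quick local check that the package builds, a sketch assuming it is run from the project/ directory:

import subprocess

# Build the source distribution; produces dist/trainer-0.1.tar.gz,
# which CMLE installs on each training node.
subprocess.check_call(['python', 'setup.py', 'sdist'])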
Experimental/distribution/multi-gpu/cmle/project/trainer/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
Experimental/distribution/multi-gpu/cmle/project/trainer/sample_model.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf


def _conv(x, kernel, name, log=False):
    """Conv -> ReLU -> 2x2 max-pool block."""
    with tf.variable_scope(name):
        W = tf.get_variable(initializer=tf.truncated_normal(shape=kernel, stddev=0.01), name='W')
        b = tf.get_variable(initializer=tf.constant(0.0, shape=[kernel[3]]), name='b')
        conv = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
        activation = tf.nn.relu(tf.add(conv, b))
        pool = tf.nn.max_pool(activation, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        if log:
            tf.summary.histogram("weights", W)
            tf.summary.histogram("biases", b)
            tf.summary.histogram("activations", activation)
        return pool


def _dense(x, size_in, size_out, name, relu=False, log=False):
    """Fully connected layer; flattens its input first."""
    with tf.variable_scope(name):
        flat = tf.reshape(x, [-1, size_in])
        W = tf.get_variable(initializer=tf.truncated_normal([size_in, size_out], stddev=0.1), name='W')
        b = tf.get_variable(initializer=tf.constant(0.0, shape=[size_out]), name='b')
        activation = tf.add(tf.matmul(flat, W), b)
        if relu:
            activation = tf.nn.relu(activation)
        if log:
            tf.summary.histogram("weights", W)
            tf.summary.histogram("biases", b)
            tf.summary.histogram("activations", activation)
        return activation


def _model(features, mode, params):
    # Four conv/pool blocks reduce 32x32x3 inputs to 2x2x512.
    input_layer = tf.reshape(features, [-1, 32, 32, 3])
    conv1 = _conv(input_layer, kernel=[5, 5, 3, 128], name='conv1', log=params['log'])
    conv2 = _conv(conv1, kernel=[5, 5, 128, 128], name='conv2', log=params['log'])
    conv3 = _conv(conv2, kernel=[3, 3, 128, 256], name='conv3', log=params['log'])
    conv4 = _conv(conv3, kernel=[3, 3, 256, 512], name='conv4', log=params['log'])
    dense = _dense(conv4, size_in=2*2*512, size_out=params['dense_units'],
                   name='Dense', relu=True, log=params['log'])

    if mode == tf.estimator.ModeKeys.TRAIN:
        # Note: in TF 1.x the second argument of tf.nn.dropout is keep_prob,
        # so params['drop_out'] is the probability of *keeping* a unit.
        dense = tf.nn.dropout(dense, params['drop_out'])

    logits = _dense(dense, size_in=params['dense_units'],
                    size_out=10, name='Output', relu=False, log=params['log'])
    return logits


def model_fn(features, labels, mode, params):
    logits = _model(features, mode, params)
    predictions = {"logits": logits,
                   "classes": tf.argmax(input=logits, axis=1),
                   "probabilities": tf.nn.softmax(logits, name='softmax')}
    export_outputs = {'predictions': tf.estimator.export.PredictOutput(predictions)}

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    if mode == tf.estimator.ModeKeys.TRAIN:
        learning_rate = tf.train.exponential_decay(params['learning_rate'],
                                                   tf.train.get_global_step(),
                                                   decay_steps=100000,
                                                   decay_rate=0.96)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
        tf.summary.scalar('learning_rate', learning_rate)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    if mode == tf.estimator.ModeKeys.EVAL:
        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(logits, axis=1))
        metrics = {'accuracy': accuracy}
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=metrics)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions, export_outputs=export_outputs)
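
model_fn branches on the mode key: TRAIN builds an SGD train_op with an exponentially decaying learning rate, EVAL adds an accuracy metric, and PREDICT wires the predictions dict into an export signature. A minimal local smoke test of this model_fn, assuming the trainer package is importable and using made-up shapes and parameter values (not part of the sample):

import tensorflow as tf
import trainer.sample_model as sm

# Illustrative hyperparameters; 'drop_out' is a keep probability.
params = {'drop_out': 0.5, 'dense_units': 1024, 'learning_rate': 1e-3, 'log': False}
estimator = tf.estimator.Estimator(
    model_fn=sm.model_fn, model_dir='/tmp/sample_model_test', params=params)

def tiny_input_fn():
    # One batch of random 32x32x3 "images" with integer labels in [0, 10),
    # just enough to exercise the TRAIN branch of the graph.
    images = tf.random_uniform([8, 32, 32, 3])
    labels = tf.random_uniform([8], maxval=10, dtype=tf.int64)
    return tf.data.Dataset.from_tensors((images, labels)).repeat()

estimator.train(input_fn=tiny_input_fn, steps=5)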
Experimental/distribution/multi-gpu/cmle/project/trainer/task.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tensorflow as tf

import trainer.sample_model as sm

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer(
    'max_steps', 1000, 'max steps for training.')
tf.app.flags.DEFINE_string(
    'output_dir', '', 'GCS location of the root directory for checkpoints and exported models.')
tf.app.flags.DEFINE_string(
    'model_name', 'sample_model', 'model name.')
tf.app.flags.DEFINE_integer(
    'train_batch_size', 200, 'batch size for training.')
tf.app.flags.DEFINE_integer(
    'eval_batch_size', 200, 'batch size for evaluation.')
tf.app.flags.DEFINE_integer(
    'eval_steps', 50, 'number of steps used in the evaluation phase.')
tf.app.flags.DEFINE_integer(
    'tf_random_seed', 19851211, 'random seed for TensorFlow.')
tf.app.flags.DEFINE_integer(
    'save_checkpoints_steps', 500, 'steps between checkpoint saves.')
tf.app.flags.DEFINE_string(
    'train_data_pattern', 'cifar-10/train*.tfrecord', 'path to the training dataset on GCS.')
tf.app.flags.DEFINE_string(
    'eval_data_pattern', 'cifar-10/valid*.tfrecord', 'path to the eval dataset on GCS.')
tf.app.flags.DEFINE_float(
    'learning_rate', 1e-3, 'learning rate.')
tf.app.flags.DEFINE_integer(
    'num_gpus', 1, 'number of GPUs in the single-node multi-GPU setting.')
tf.app.flags.DEFINE_integer(
    'num_gpus_per_worker', 0, 'number of GPUs for each node.')
tf.app.flags.DEFINE_bool(
    'auto_shard_dataset', False,
    'whether to auto-shard the dataset when there are multiple workers.')
tf.app.flags.DEFINE_float(
    'drop_out_rate', 1e-2, 'keep probability passed to tf.nn.dropout.')
tf.app.flags.DEFINE_integer(
    'dense_units', 1024, 'units in the dense layer.')

tf.logging.set_verbosity(tf.logging.INFO)


def parse_tfrecord(example):
    """Decodes one serialized tf.train.Example into (image, label)."""
    feature = {'label': tf.FixedLenFeature((), tf.int64),
               'image': tf.FixedLenFeature((), tf.string, default_value="")}
    parsed = tf.parse_single_example(example, feature)
    # Images are stored as raw float64 bytes; decode, cast and reshape.
    image = tf.decode_raw(parsed['image'], tf.float64)
    image = tf.cast(image, tf.float32)
    image = tf.reshape(image, [32, 32, 3])
    return image, parsed['label']


def image_scaling(x):
    return tf.image.per_image_standardization(x)


def distort(x):
    # Pad-and-crop plus random horizontal flip, applied only in training.
    x = tf.image.resize_image_with_crop_or_pad(x, 40, 40)
    x = tf.random_crop(x, [32, 32, 3])
    x = tf.image.random_flip_left_right(x)
    return x


def dataset_input_fn(params):
    dataset = tf.data.TFRecordDataset(params['filenames'],
                                      num_parallel_reads=params['threads'])
    dataset = dataset.map(parse_tfrecord, num_parallel_calls=params['threads'])
    dataset = dataset.map(
        lambda x, y: (image_scaling(x), y), num_parallel_calls=params['threads'])
    if params['mode'] == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.map(
            lambda x, y: (distort(x), y), num_parallel_calls=params['threads'])
        dataset = dataset.shuffle(buffer_size=params['shuffle_buff'])
    dataset = dataset.repeat()
    dataset = dataset.batch(params['batch'])
    dataset = dataset.prefetch(8 * params['batch'])  # prefetch is in units of batches here
    return dataset


def train_dataset_input_fn(pattern):
    files = tf.gfile.Glob(pattern)
    params = {'filenames': files, 'mode': tf.estimator.ModeKeys.TRAIN,
              'threads': 16, 'shuffle_buff': 100000, 'batch': FLAGS.train_batch_size}
    return dataset_input_fn(params)


def eval_dataset_input_fn(pattern):
    files = tf.gfile.Glob(pattern)
    params = {'filenames': files, 'mode': tf.estimator.ModeKeys.EVAL,
              'threads': 16, 'batch': FLAGS.eval_batch_size}
    return dataset_input_fn(params)


def serving_input_fn():
    receiver_tensor = {'images': tf.placeholder(shape=[None, 32, 32, 3], dtype=tf.float32)}
    # Apply the same per-image standardization used at training time.
    features = tf.map_fn(image_scaling, receiver_tensor['images'])
    return tf.estimator.export.TensorServingInputReceiver(features, receiver_tensor)


def train_and_evaluate():
    model_dir = os.path.join(FLAGS.output_dir, FLAGS.model_name)

    # MirroredStrategy: multi-worker if num_gpus_per_worker is set,
    # otherwise single-node multi-GPU, otherwise no distribution.
    if FLAGS.num_gpus_per_worker > 0:
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus_per_worker=FLAGS.num_gpus_per_worker,
            auto_shard_dataset=FLAGS.auto_shard_dataset)
    elif FLAGS.num_gpus > 0:
        distribution = tf.contrib.distribute.MirroredStrategy(num_gpus=FLAGS.num_gpus)
    else:
        distribution = None

    # Configuration for the Estimator; the flag holds a step count,
    # so it is passed as save_checkpoints_steps.
    config = tf.estimator.RunConfig(
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=5,
        session_config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True),
        train_distribute=distribution,
        tf_random_seed=FLAGS.tf_random_seed)

    model_params = {
        'drop_out': FLAGS.drop_out_rate,
        'dense_units': FLAGS.dense_units,
        'learning_rate': FLAGS.learning_rate,
        'log': True}

    # Create the Estimator.
    estimator = tf.estimator.Estimator(
        model_fn=sm.model_fn,
        model_dir=model_dir,
        params=model_params,
        config=config)

    # Specify training data paths, batch size and max steps.
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda: train_dataset_input_fn(FLAGS.train_data_pattern),
        max_steps=FLAGS.max_steps)

    # Configuration for model export.
    exporter = tf.estimator.LatestExporter(
        name='export',
        serving_input_receiver_fn=serving_input_fn,
        assets_extra=None, as_text=False, exports_to_keep=5)

    # Specify validation data paths, evaluation steps and the exporter.
    eval_spec = tf.estimator.EvalSpec(
        input_fn=lambda: eval_dataset_input_fn(FLAGS.eval_data_pattern),
        steps=FLAGS.eval_steps, exporters=exporter)

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


def main(unused_argv=None):
    tf.logging.info(tf.__version__)
    train_and_evaluate()


if __name__ == '__main__':
    tf.app.run()
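
parse_tfrecord expects each record's 'image' feature to hold the raw bytes of a float64 [32, 32, 3] array next to an int64 'label'. The sample does not include a converter, but a matching writer might look like this sketch (file name and data are illustrative):

import numpy as np
import tensorflow as tf

def make_example(image, label):
    # Store the image exactly as parse_tfrecord reads it back:
    # raw float64 bytes decoded with tf.decode_raw(..., tf.float64).
    feature = {
        'image': tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[image.astype(np.float64).tobytes()])),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

with tf.python_io.TFRecordWriter('train_000.tfrecord') as writer:
    image = np.random.rand(32, 32, 3)  # stand-in for one CIFAR-10 image
    writer.write(make_example(image, 7).SerializeToString())

On the export side, LatestExporter writes timestamped SavedModel directories under <output_dir>/<model_name>/export/export. A hedged sketch of loading one of those exports and running a prediction locally (the path is a placeholder for an actual timestamped directory):

import numpy as np
import tensorflow as tf

export_dir = '/tmp/sample_model/export/export/1545390000'  # placeholder
predict_fn = tf.contrib.predictor.from_saved_model(export_dir)
# The serving signature takes float32 images under the 'images' key and
# returns the logits/classes/probabilities defined in model_fn.
batch = {'images': np.random.rand(1, 32, 32, 3).astype(np.float32)}
print(predict_fn(batch)['probabilities'])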
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Single host, multiple GPUs (K80 * 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "jobId: sample_model_20181221_112757\n",
      "state: QUEUED\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Job [sample_model_20181221_112757] submitted successfully.\n",
      "Your job is still active. You may view the status of your job with the command\n",
      "\n",
      "  $ gcloud ml-engine jobs describe sample_model_20181221_112757\n",
      "\n",
      "or continue streaming the logs with the command\n",
      "\n",
      "  $ gcloud ml-engine jobs stream-logs sample_model_20181221_112757\n"
     ]
    }
   ],
   "source": [
    "%%bash\n",
    "\n",
    "PROJECT_ID=\"YOUR-PROJECT-ID\"\n",
    "BUCKET_ID=\"YOUR-BUCKET-ID\"\n",
    "REGION=\"YOUR-REGION\"\n",
    "# https://cloud.google.com/ml-engine/docs/tensorflow/regions\n",
    "\n",
    "TRAINER_PACKAGE_PATH=$(pwd)/project/trainer\n",
    "now=$(date +\"%Y%m%d_%H%M%S\")\n",
    "JOB_NAME=\"sample_model_$now\"\n",
    "MAIN_TRAINER_MODULE=trainer.task\n",
    "PACKAGE_STAGING_PATH=gs://$BUCKET_ID/staging\n",
    "JOB_DIR=gs://$BUCKET_ID/sample_model_job_dir\n",
    "RUNTIME_VERSION=\"1.11\"\n",
    "# https://cloud.google.com/ml-engine/docs/tensorflow/runtime-version-list\n",
    "\n",
    "gcloud ml-engine jobs submit training $JOB_NAME \\\n",
    "  --package-path $TRAINER_PACKAGE_PATH \\\n",
    "  --module-name $MAIN_TRAINER_MODULE \\\n",
    "  --job-dir $JOB_DIR \\\n",
    "  --project $PROJECT_ID \\\n",
    "  --region $REGION \\\n",
    "  --runtime-version $RUNTIME_VERSION \\\n",
    "  --config config.yaml \\\n",
    "  -- \\\n",
    "  --train_data_pattern \"gs://$BUCKET_ID/data/cifar10_data_00*\" \\\n",
    "  --eval_data_pattern \"gs://$BUCKET_ID/data/cifar10_data_01*\" \\\n",
    "  --max_steps 10000 \\\n",
    "  --num_gpus 8 \\\n",
    "  --output_dir \"gs://$BUCKET_ID/model\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
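
The notebook's stderr output shows the gcloud commands for describing the job and streaming its logs. The same state check can be scripted; a small sketch using the Google API client, with the project ID placeholder and the job ID from the output above:

from googleapiclient import discovery

ml = discovery.build('ml', 'v1')
name = 'projects/{}/jobs/{}'.format('YOUR-PROJECT-ID', 'sample_model_20181221_112757')
job = ml.projects().jobs().get(name=name).execute()
print(job['state'])  # e.g. QUEUED, PREPARING, RUNNING, SUCCEEDED, FAILED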
