Skip to content

Commit 84b58a6

Browse files
jmchen-gmrry
authored and committed
Implement distributed inception (tensorflow#44)
Implements a distributed trainer for Inception.
1 parent 9a1dfdf commit 84b58a6

File tree

6 files changed

+842
-302
lines changed

6 files changed

+842
-302
lines changed

inception/README.md

Lines changed: 417 additions & 285 deletions
Large diffs are not rendered by default.

inception/inception/BUILD

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,17 @@ py_binary(
102102
],
103103
)
104104

105+
# Binary target wrapping imagenet_distributed_train.py.
py_binary(
    name = "imagenet_distributed_train",
    srcs = ["imagenet_distributed_train.py"],
    deps = [
        ":imagenet_data",
        ":inception_distributed_train",
    ],
)
115+
105116
py_binary(
106117
name = "flowers_train",
107118
srcs = [
@@ -124,6 +135,17 @@ py_library(
124135
],
125136
)
126137

138+
# Library target wrapping inception_distributed_train.py.
py_library(
    name = "inception_distributed_train",
    srcs = ["inception_distributed_train.py"],
    deps = [
        ":image_processing",
        ":inception",
    ],
)
148+
127149
py_binary(
128150
name = "build_image_data",
129151
srcs = ["data/build_image_data.py"],
inception/inception/imagenet_distributed_train.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2016 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# pylint: disable=line-too-long
16+
"""A binary to train Inception in a distributed manner using multiple systems.
17+
18+
Please see accompanying README.md for details and instructions.
19+
"""
20+
from __future__ import absolute_import
21+
from __future__ import division
22+
from __future__ import print_function
23+
24+
import tensorflow as tf
25+
26+
from inception import inception_distributed_train
27+
from inception.imagenet_data import ImagenetData
28+
29+
FLAGS = tf.app.flags.FLAGS
30+
31+
32+
def main(unused_args):
  """Starts this task's TensorFlow server and runs its role in the cluster.

  Reads job_name, task_id, ps_hosts, worker_hosts, subset and train_dir from
  FLAGS. A 'ps' task blocks in server.join(). A 'worker' task builds the
  ImageNet dataset and hands it to inception_distributed_train.train().

  Args:
    unused_args: Ignored.

  Raises:
    ValueError: If FLAGS.job_name is neither 'ps' nor 'worker', or if
      dataset.data_files() is empty for a worker task.
  """
  # Explicit raises instead of asserts: asserts vanish under `python -O`.
  if FLAGS.job_name not in ('ps', 'worker'):
    raise ValueError('job_name must be ps or worker')

  # Extract all the hostnames for the ps and worker jobs to construct the
  # cluster spec.
  ps_hosts = FLAGS.ps_hosts.split(',')
  worker_hosts = FLAGS.worker_hosts.split(',')
  tf.logging.info('PS hosts are: %s', ps_hosts)
  tf.logging.info('Worker hosts are: %s', worker_hosts)

  # Build the cluster description once and share it between the server and
  # the trainer so the two can never disagree.
  cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,
                                       'worker': worker_hosts})
  server = tf.train.Server(
      cluster_spec,
      job_name=FLAGS.job_name,
      task_index=FLAGS.task_id)

  if FLAGS.job_name == 'ps':
    # `ps` jobs wait for incoming connections from the workers.
    server.join()
    return

  # `worker` jobs will actually do the work.
  dataset = ImagenetData(subset=FLAGS.subset)
  if not dataset.data_files():
    raise ValueError('dataset.data_files() returned no files for subset %r'
                     % FLAGS.subset)

  # Only the chief checks for or creates train_dir.
  if FLAGS.task_id == 0 and not tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.MakeDirs(FLAGS.train_dir)

  inception_distributed_train.train(server.target, dataset, cluster_spec)
62+
63+
if __name__ == '__main__':
  # Set log verbosity to INFO before starting the app.
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()

0 commit comments

Comments
 (0)