
Commit 87ab169

crccw authored and tensorflower-gardener committed
Update MultiWorkerMirroredStrategy API doc
PiperOrigin-RevId: 323064599 Change-Id: Ie75ae964a8edbc9060ef5f1731c8c9ab34404fe6
1 parent cbc87da commit 87ab169

File tree

1 file changed: +48 −50 lines

tensorflow/python/distribute/collective_all_reduce_strategy.py

Lines changed: 48 additions & 50 deletions
@@ -44,37 +44,53 @@
 from tensorflow.python.util.tf_export import tf_export
 
 
-# TODO(yuefengz): support in-graph replication.
 @tf_export("distribute.experimental.MultiWorkerMirroredStrategy", v1=[])
 class CollectiveAllReduceStrategy(distribute_lib.Strategy):
   """A distribution strategy for synchronous training on multiple workers.
 
   This strategy implements synchronous distributed training across multiple
   workers, each with potentially multiple GPUs. Similar to
-  `tf.distribute.MirroredStrategy`, it creates copies of all variables in the
-  model on each device across all workers.
-
-  It uses CollectiveOps's implementation of multi-worker all-reduce to
-  to keep variables in sync. A collective op is a single op in the
-  TensorFlow graph which can automatically choose an all-reduce algorithm in
-  the TensorFlow runtime according to hardware, network topology and tensor
-  sizes.
-
-  By default it uses all local GPUs or CPU for single-worker training.
-
-  When 'TF_CONFIG' environment variable is set, it parses cluster_spec,
-  task_type and task_id from 'TF_CONFIG' and turns into a multi-worker strategy
-  which mirrored models on GPUs of all machines in a cluster. In the current
-  implementation, it uses all GPUs in a cluster and it assumes all workers have
-  the same number of GPUs.
-
-  You can also pass a `distribute.cluster_resolver.ClusterResolver` instance
-  when instantiating the strategy. The task_type, task_id etc. will be parsed
-  from the resolver instance instead of from the `TF_CONFIG` env var.
-
-  It supports both eager mode and graph mode. However, for eager mode, it has to
-  set up the eager context in its constructor and therefore all ops in eager
-  mode have to run after the strategy object is created.
+  `tf.distribute.MirroredStrategy`, it replicates all variables and computations
+  to each local device. The difference is that it uses a distributed collective
+  implementation (e.g. all-reduce), so that multiple workers can work together.
+
+  You need to launch your program on each worker and configure
+  `cluster_resolver` correctly. For example, if you are using
+  `tf.distribute.cluster_resolver.TFConfigClusterResolver`, each worker needs to
+  have its corresponding `task_type` and `task_id` set in the `TF_CONFIG`
+  environment variable.
+
+  Your program runs on each worker as-is. Note that collectives require each
+  worker to participate. All `tf.distribute` and non `tf.distribute` API may use
+  collectives internally, e.g. checkpointing and saving since reading a
+  `tf.Variable` with `tf.VariableSynchronization.ON_READ` all-reduces the value.
+  Therefore it's recommended to run exactly the same program on each worker.
+  Dispatching based on `task_type` or `task_id` of the worker is error-prone.
+
+  `cluster_resolver.num_accelerators()` determines the number of GPUs the
+  strategy uses. If it's zero, the strategy uses the CPU. All workers need to
+  use the same number of devices, otherwise the behavior is undefined.
+
+  This strategy is not intended for TPU. Use
+  `tf.distribute.experimental.TPUStrategy` instead.
+
+  __Saving__
+
+  You need to save and checkpoint on all workers instead of just one. This is
+  because variables whose synchronization=ON_READ triggers aggregation during
+  saving. It's recommended to save to a different path on each worker to avoid
+  race conditions. Each worker saves the same thing. See
+  [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#model_saving_and_loading)
+  tutorial for examples.
+
+  __Known Issues__
+
+  * `tf.distribute.cluster_resolver.TFConfigClusterResolver` does not return the
+    correct number of accelerators. The strategy uses all available GPUs if
+    `cluster_resolver` is `tf.distribute.cluster_resolver.TFConfigClusterResolver`
+    or `None`.
+  * In eager mode, the strategy needs to be created before calling any other
+    Tensorflow API.
 
   """
   # TODO(anjalisridhar): Update our guides with examples showing how we can use
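The rewritten docstring above leans on the `TF_CONFIG` environment variable for multi-worker setup. A minimal sketch of what one worker's configuration could look like (the hosts, ports, and two-worker topology are made up for illustration; only the JSON layout follows the docstring):

```python
import json
import os

# Hypothetical two-worker cluster; the addresses are illustrative only.
cluster = {"worker": ["host1:12345", "host2:23456"]}

# Each worker sets its own task type and index; this is what worker 0
# would export before launching the program.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": cluster,
    "task": {"type": "worker", "index": 0},
})

# tf.distribute.cluster_resolver.TFConfigClusterResolver would parse this
# back out; plain json is used here to keep the sketch self-contained.
tf_config = json.loads(os.environ["TF_CONFIG"])
print(tf_config["task"]["type"], tf_config["task"]["index"])  # worker 0
```

Worker 1 would export the same `cluster` dict with `"index": 1`; per the docstring, the program itself should be identical on every worker.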
@@ -87,14 +103,13 @@ def __init__(
     """Creates the strategy.
 
     Args:
-      communication: optional Enum of type
-        `distribute.experimental.CollectiveCommunication`. This provides a way
-        for the user to override the choice of collective op communication.
-        Possible values include `AUTO`, `RING`, and `NCCL`.
-      cluster_resolver: optional `distribute.cluster_resolver.ClusterResolver`
-        object. The default ClusterResolver that is used is the
-        TFConfigClusterResolver which is instantiated from the TF_CONFIG env
-        var.
+      communication: optional
+        `tf.distribute.experimental.CollectiveCommunication`. This is a hint on
+        the preferred collective communication implementation. Possible values
+        include `AUTO`, `RING`, and `NCCL`.
+      cluster_resolver: optional
+        `tf.distribute.cluster_resolver.ClusterResolver`. If `None`,
+        `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.
     """
     # TODO(b/150151677): consider move communication to CollectiveHints.
     super(CollectiveAllReduceStrategy, self).__init__(
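The updated `communication` docstring stresses that the value is a hint, not a guarantee. As a rough illustration of what "hint" means here (this toy resolver is NOT TensorFlow's actual selection logic; only the `AUTO`/`RING`/`NCCL` names come from the documented values):

```python
# Toy illustration of hint resolution; not TensorFlow's implementation.
AUTO, RING, NCCL = "AUTO", "RING", "NCCL"

def resolve_communication(hint, num_gpus):
    """Map a communication hint to a concrete choice (hypothetical logic)."""
    if hint != AUTO:
        return hint  # an explicit hint is honored
    # Under AUTO the runtime is free to choose; NCCL only makes sense
    # when GPUs are present, so fall back to RING on CPU-only workers.
    return NCCL if num_gpus > 0 else RING

print(resolve_communication(AUTO, num_gpus=2))  # NCCL
print(resolve_communication(AUTO, num_gpus=0))  # RING
```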
@@ -121,23 +136,6 @@ def _from_local_devices(
     obj.extended._initialize_local(TFConfigClusterResolver(), devices=devices)  # pylint: disable=protected-access
     return obj
 
-  def scope(self):  # pylint: disable=useless-super-delegation
-    """Returns a context manager selecting this Strategy as current.
-
-    Inside a `with strategy.scope():` code block, this thread
-    will use a variable creator set by `strategy`, and will
-    enter its "cross-replica context".
-
-    In `MultiWorkerMirroredStrategy`, all variables created inside
-    `strategy.scope() will be mirrored on all replicas of each worker.
-    Moreover, it also sets a default device scope so that ops without
-    specified devices will end up on the correct worker.
-
-    Returns:
-      A context manager to use for creating variables with this strategy.
-    """
-    return super(CollectiveAllReduceStrategy, self).scope()
-
   @property
   def cluster_resolver(self):
     """Returns the cluster resolver associated with this strategy.
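The __Saving__ section added by this commit recommends saving to a different path on each worker. A hypothetical helper sketching that pattern (the `workertemp_` naming is made up; the Keras multi-worker tutorial linked in the docstring shows the recommended approach in full):

```python
import os

def worker_save_path(base_dir, task_id):
    """Build a per-worker checkpoint path to avoid write races.

    Hypothetical helper: each worker saves the same thing, but to its
    own directory, as the docstring's __Saving__ section suggests.
    """
    return os.path.join(base_dir, "workertemp_{}".format(task_id))

for task_id in (0, 1):
    print(worker_save_path("/tmp/ckpt", task_id))
```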
