1515"""A Python interface for creating dataset servers."""
1616
1717import collections
18+ from typing import Iterable
1819
1920# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
2021from tensorflow .core .protobuf import service_config_pb2
2425from tensorflow .python .util .tf_export import tf_export
2526
2627
27- def _get_time_or_placeholder (value ):
28+ def _get_time_or_placeholder (value ) -> int :
2829 """Modifies time-based config values to account for special behaviors."""
2930
3031 # Servers interpret time values of 0 to mean "choose a reasonable
@@ -217,7 +218,7 @@ def start(self):
217218 """
218219 self ._server .start ()
219220
220- def join (self ):
221+ def join (self ) -> None :
221222 """Blocks until the server has shut down.
222223
223224 This is useful when starting a dedicated dispatch process.
@@ -234,7 +235,7 @@ def join(self):
234235 """
235236 self ._server .join ()
236237
237- def stop (self ):
238+ def stop (self ) -> None :
238239 """Stops the server.
239240
240241 Raises:
@@ -244,7 +245,7 @@ def stop(self):
244245 self ._stop ()
245246
246247 @property
247- def target (self ):
248+ def target (self ) -> str :
248249 """Returns a target that can be used to connect to the server.
249250
250251 >>> dispatcher = tf.data.experimental.service.DispatchServer()
@@ -258,7 +259,7 @@ def target(self):
258259 return "{0}://localhost:{1}" .format (self ._config .protocol ,
259260 self ._server .bound_port ())
260261
261- def _stop (self ):
262+ def _stop (self ) -> None :
262263 """Stops the server.
263264
264265 Raises:
@@ -267,22 +268,23 @@ def _stop(self):
267268 """
268269 self ._server .stop ()
269270
270- def __del__ (self ):
271+ def __del__ (self ) -> None :
271272 self ._stop ()
272273
273274 @property
274- def _address (self ):
275+ def _address (self ) -> str :
275276 """Returns the address of the server.
276277
277278 The returned string will be in the form address:port, e.g. "localhost:1000".
278279 """
279280 return "localhost:{0}" .format (self ._server .bound_port ())
280281
281- def _num_workers (self ):
282+ def _num_workers (self ) -> int :
282283 """Returns the number of workers registered with the dispatcher."""
283284 return self ._server .num_workers ()
284285
285- def _snapshot_streams (self , path ):
286+ def _snapshot_streams (
287+ self , path ) -> Iterable [_pywrap_server_lib .SnapshotStreamInfoWrapper ]:
286288 """Returns information about all the streams for a snapshot."""
287289 return self ._server .snapshot_streams (path )
288290
@@ -405,7 +407,7 @@ def __init__(self, config, start=True):
405407 if start :
406408 self ._server .start ()
407409
408- def start (self ):
410+ def start (self ) -> None :
409411 """Starts this server.
410412
411413 Raises:
@@ -414,7 +416,7 @@ def start(self):
414416 """
415417 self ._server .start ()
416418
417- def join (self ):
419+ def join (self ) -> None :
418420 """Blocks until the server has shut down.
419421
420422 This is useful when starting a dedicated worker process.
@@ -433,7 +435,7 @@ def join(self):
433435 """
434436 self ._server .join ()
435437
436- def stop (self ):
438+ def stop (self ) -> None :
437439 """Stops the server.
438440
439441 Raises:
@@ -442,7 +444,7 @@ def stop(self):
442444 """
443445 self ._stop ()
444446
445- def _stop (self ):
447+ def _stop (self ) -> None :
446448 """Stops the server.
447449
448450 Raises:
@@ -451,22 +453,23 @@ def _stop(self):
451453 """
452454 self ._server .stop ()
453455
454- def __del__ (self ):
456+ def __del__ (self ) -> None :
455457 self ._stop ()
456458
457459 @property
458- def _address (self ):
460+ def _address (self ) -> str :
459461 """Returns the address of the server.
460462
461463 The returned string will be in the form address:port, e.g. "localhost:1000".
462464 """
463465 return "localhost:{0}" .format (self ._server .bound_port ())
464466
465- def _num_tasks (self ):
467+ def _num_tasks (self ) -> int :
466468 """Returns the number of tasks currently being executed on the worker."""
467469 return self ._server .num_tasks ()
468470
469- def _snapshot_task_progresses (self ):
471+ def _snapshot_task_progresses (
472+ self ) -> Iterable [_pywrap_server_lib .SnapshotTaskProgressWrapper ]:
470473 """Returns the progresses of the snapshot tasks currently being executed.
471474
472475 Returns:
0 commit comments