Skip to content

Commit 5a596bc

Browse files
yangustc07tensorflower-gardener
authored andcommitted
#tf-data-service Add return type annotations for tf.data service lib.
PiperOrigin-RevId: 557348637
1 parent a35d2dd commit 5a596bc

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

tensorflow/python/data/experimental/service/server_lib.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""A Python interface for creating dataset servers."""
1616

1717
import collections
18+
from typing import Iterable
1819

1920
# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
2021
from tensorflow.core.protobuf import service_config_pb2
@@ -24,7 +25,7 @@
2425
from tensorflow.python.util.tf_export import tf_export
2526

2627

27-
def _get_time_or_placeholder(value):
28+
def _get_time_or_placeholder(value) -> int:
2829
"""Modifies time-based config values to account for special behaviors."""
2930

3031
# Servers interpret time values of 0 to mean "choose a reasonable
@@ -217,7 +218,7 @@ def start(self):
217218
"""
218219
self._server.start()
219220

220-
def join(self):
221+
def join(self) -> None:
221222
"""Blocks until the server has shut down.
222223
223224
This is useful when starting a dedicated dispatch process.
@@ -234,7 +235,7 @@ def join(self):
234235
"""
235236
self._server.join()
236237

237-
def stop(self):
238+
def stop(self) -> None:
238239
"""Stops the server.
239240
240241
Raises:
@@ -244,7 +245,7 @@ def stop(self):
244245
self._stop()
245246

246247
@property
247-
def target(self):
248+
def target(self) -> str:
248249
"""Returns a target that can be used to connect to the server.
249250
250251
>>> dispatcher = tf.data.experimental.service.DispatchServer()
@@ -258,7 +259,7 @@ def target(self):
258259
return "{0}://localhost:{1}".format(self._config.protocol,
259260
self._server.bound_port())
260261

261-
def _stop(self):
262+
def _stop(self) -> None:
262263
"""Stops the server.
263264
264265
Raises:
@@ -267,22 +268,23 @@ def _stop(self):
267268
"""
268269
self._server.stop()
269270

270-
def __del__(self):
271+
def __del__(self) -> None:
271272
self._stop()
272273

273274
@property
274-
def _address(self):
275+
def _address(self) -> str:
275276
"""Returns the address of the server.
276277
277278
The returned string will be in the form address:port, e.g. "localhost:1000".
278279
"""
279280
return "localhost:{0}".format(self._server.bound_port())
280281

281-
def _num_workers(self):
282+
def _num_workers(self) -> int:
282283
"""Returns the number of workers registered with the dispatcher."""
283284
return self._server.num_workers()
284285

285-
def _snapshot_streams(self, path):
286+
def _snapshot_streams(
287+
self, path) -> Iterable[_pywrap_server_lib.SnapshotStreamInfoWrapper]:
286288
"""Returns information about all the streams for a snapshot."""
287289
return self._server.snapshot_streams(path)
288290

@@ -405,7 +407,7 @@ def __init__(self, config, start=True):
405407
if start:
406408
self._server.start()
407409

408-
def start(self):
410+
def start(self) -> None:
409411
"""Starts this server.
410412
411413
Raises:
@@ -414,7 +416,7 @@ def start(self):
414416
"""
415417
self._server.start()
416418

417-
def join(self):
419+
def join(self) -> None:
418420
"""Blocks until the server has shut down.
419421
420422
This is useful when starting a dedicated worker process.
@@ -433,7 +435,7 @@ def join(self):
433435
"""
434436
self._server.join()
435437

436-
def stop(self):
438+
def stop(self) -> None:
437439
"""Stops the server.
438440
439441
Raises:
@@ -442,7 +444,7 @@ def stop(self):
442444
"""
443445
self._stop()
444446

445-
def _stop(self):
447+
def _stop(self) -> None:
446448
"""Stops the server.
447449
448450
Raises:
@@ -451,22 +453,23 @@ def _stop(self):
451453
"""
452454
self._server.stop()
453455

454-
def __del__(self):
456+
def __del__(self) -> None:
455457
self._stop()
456458

457459
@property
458-
def _address(self):
460+
def _address(self) -> str:
459461
"""Returns the address of the server.
460462
461463
The returned string will be in the form address:port, e.g. "localhost:1000".
462464
"""
463465
return "localhost:{0}".format(self._server.bound_port())
464466

465-
def _num_tasks(self):
467+
def _num_tasks(self) -> int:
466468
"""Returns the number of tasks currently being executed on the worker."""
467469
return self._server.num_tasks()
468470

469-
def _snapshot_task_progresses(self):
471+
def _snapshot_task_progresses(
472+
self) -> Iterable[_pywrap_server_lib.SnapshotTaskProgressWrapper]:
470473
"""Returns the progresses of the snapshot tasks currently being executed.
471474
472475
Returns:

0 commit comments

Comments
 (0)