Skip to content

Commit 3707838

Browse files
andyly authored and tensorflower-gardener committed
Pass a non-empty MLIR module serialized string when constructing TpuCompilationCacheKey.

Added a test for the MLIR bridge using TPUStrategy that compiles two programs with the same signature but different bodies.

PiperOrigin-RevId: 323096104
Change-Id: I2d2cd7033f762a0756b7de2ed44aa411234d8ca9
1 parent 5966258 commit 3707838

File tree

3 files changed

+104
-3
lines changed

3 files changed

+104
-3
lines changed

tensorflow/core/tpu/kernels/tpu_compile_op_common.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -662,9 +662,8 @@ Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
662662
}
663663

664664
const TpuCompilationCacheKey key = CreateCompilationCacheKey(
665-
function_.name(), metadata_.function_library_fingerprint(),
666-
/*mlir_module=*/"", guaranteed_constants, dynamic_shapes, metadata_,
667-
*mesh_state);
665+
function_.name(), metadata_.function_library_fingerprint(), mlir_module_,
666+
guaranteed_constants, dynamic_shapes, metadata_, *mesh_state);
668667

669668
// Process-wide cache of TPU executables.
670669
TpuCompilationCacheInterface* cache;

tensorflow/python/distribute/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,21 @@ tpu_py_test(
657657
],
658658
)
659659

# Regression test for TPU compilation-cache keying: two programs with the
# same signature but different bodies must compile to different executables.
tpu_py_test(
    name = "tpu_strategy_compilation_test",
    srcs = ["tpu_strategy_compilation_test.py"],
    disable_experimental = True,
    # Keep the MLIR bridge ON: the test exercises cache keys that include the
    # serialized MLIR module.
    disable_mlir_bridge = False,
    python_version = "PY3",
    # Needs real TPU hardware; not runnable in OSS CI.
    tags = ["no_oss"],
    deps = [
        ":tpu_strategy",
        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
        "//tensorflow/python/eager:remote",
        "//tensorflow/python/eager:test",
    ],
)
660675
# Used only by estimator.
661676
py_library(
662677
name = "estimator_training",
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUStrategy in regards to compiling programs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import flags
from tensorflow.python.tpu import tpu_strategy_util

FLAGS = flags.FLAGS
# Command-line flags identifying the TPU worker the test connects to.
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")
35+
def get_tpu_cluster_resolver():
  """Builds a TPUClusterResolver from the --tpu/--zone/--project flags."""
  return tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
def get_tpu_strategy():
  """Connects to and initializes the TPU system, returning a TPUStrategyV2."""
  cluster_resolver = get_tpu_cluster_resolver()
  # Establish the remote connection before initializing the TPU system.
  remote.connect_to_cluster(cluster_resolver)
  tpu_strategy_util.initialize_tpu_system(cluster_resolver)
  return tpu_lib.TPUStrategyV2(cluster_resolver)
# TODO(b/158494076): Merge this test back into TPUStrategy tests
# (tpu_strategy_test) once MLIR bridge is enabled by default.
class TPUStrategyCompilationTest(test.TestCase):

  def test_functions_compile_same_signature(self):
    """Tests compiling different functions with the same signature."""
    strategy = get_tpu_strategy()
    replicas = strategy.num_replicas_in_sync

    @def_function.function
    def return_one():

      def computation():
        return constant_op.constant(1)

      return strategy.run(computation)

    @def_function.function
    def return_two():

      def computation():
        return constant_op.constant(2)

      return strategy.run(computation)

    # The two tf.functions have identical signatures but different bodies. If
    # the compilation cache keyed only on the signature, the second call would
    # wrongly reuse the first program; each must yield its own constant on
    # every replica.
    self.assertAllEqual([1] * replicas,
                        strategy.experimental_local_results(return_one()))

    self.assertAllEqual([2] * replicas,
                        strategy.experimental_local_results(return_two()))
# Entry point when run as a standalone test binary.
if __name__ == "__main__":
  test.main()

0 commit comments

Comments
 (0)