Skip to content

Commit 4b1f1b6

Browse files
Shining Suntensorflower-gardener
authored andcommitted
Add tests for the keras experimental save and load with DS
PiperOrigin-RevId: 248204945
1 parent 275ed3d commit 4b1f1b6

File tree

5 files changed

+266
-77
lines changed

5 files changed

+266
-77
lines changed

tensorflow/contrib/distribute/python/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ py_library(
2222
"//tensorflow/python/distribute:combinations",
2323
"//tensorflow/python/distribute:model_combinations",
2424
"//tensorflow/python/distribute:multi_worker_test_base",
25+
"//tensorflow/python/distribute:saved_model_test_base",
2526
"//tensorflow/python/distribute:single_loss_example",
2627
"//tensorflow/python/distribute:strategy_combinations",
2728
"//tensorflow/python/distribute:strategy_test_lib",

tensorflow/python/distribute/BUILD

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -972,20 +972,36 @@ py_library(
972972
],
973973
)
974974

975-
distribute_py_test(
976-
name = "saved_model_test",
977-
size = "medium",
978-
srcs = ["saved_model_test.py"],
979-
main = "saved_model_test.py",
980-
tags = [
981-
"no_pip", # b/131691139
982-
],
975+
py_library(
976+
name = "saved_model_test_base",
977+
srcs = ["saved_model_test_base.py"],
983978
deps = [
984979
":combinations",
985980
":model_combinations",
986981
":strategy_combinations",
987982
"//tensorflow/python/eager:test",
988-
"//tensorflow/python/saved_model",
989983
"//third_party/py/numpy",
990984
],
991985
)
986+
987+
distribute_py_test(
988+
name = "saved_model_save_load_test",
989+
size = "medium",
990+
srcs = ["saved_model_save_load_test.py"],
991+
main = "saved_model_save_load_test.py",
992+
deps = [
993+
":saved_model_test_base",
994+
"//tensorflow/python/saved_model",
995+
],
996+
)
997+
998+
distribute_py_test(
999+
name = "keras_saved_model_test",
1000+
size = "medium",
1001+
srcs = ["keras_saved_model_test.py"],
1002+
main = "keras_saved_model_test.py",
1003+
deps = [
1004+
":saved_model_test_base",
1005+
"//tensorflow/python/keras:saving",
1006+
],
1007+
)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for saving and loading using keras experimental APIs with DS."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from tensorflow.python.distribute import combinations
22+
from tensorflow.python.distribute import saved_model_test_base as test_base
23+
from tensorflow.python.eager import test
24+
from tensorflow.python.keras.saving import saved_model
25+
26+
27+
class KerasExperimentalSaveLoadTest(test_base.TestSavedModelBase):
28+
29+
def setUp(self):
30+
self._root_dir = 'keras_experimental_save_load'
31+
super(KerasExperimentalSaveLoadTest, self).setUp()
32+
33+
def _save_model(self, model, saved_dir):
34+
saved_model.export_saved_model(model, saved_dir)
35+
36+
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
37+
output_name):
38+
restored_keras_model = saved_model.load_from_saved_model(saved_dir)
39+
return restored_keras_model.predict(
40+
predict_dataset, steps=test_base.PREDICT_STEPS)
41+
42+
@combinations.generate(test_base.simple_models_with_strategies())
43+
def test_save_no_strategy_restore_strategy(self, model_and_input,
44+
distribution):
45+
self.run_test_save_no_strategy_restore_strategy(model_and_input,
46+
distribution)
47+
48+
@combinations.generate(test_base.simple_models_with_strategies())
49+
def test_save_strategy_restore_no_strategy(self, model_and_input,
50+
distribution):
51+
self.run_test_save_strategy_restore_no_strategy(model_and_input,
52+
distribution)
53+
54+
@combinations.generate(test_base.simple_models_with_strategy_pairs())
55+
def test_save_strategy_restore_strategy(self, model_and_input,
56+
distribution_pair):
57+
self.run_test_save_strategy_restore_strategy(model_and_input,
58+
distribution_pair)
59+
60+
61+
if __name__ == '__main__':
62+
test.main()
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for saving and loading using tf's saved_model APIs with DS."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from tensorflow.python.distribute import combinations
22+
from tensorflow.python.distribute import saved_model_test_base as test_base
23+
from tensorflow.python.eager import test
24+
from tensorflow.python.saved_model import saved_model
25+
26+
_DEFAULT_FUNCTION_KEY = 'serving_default'
27+
28+
29+
class SavedModelSaveAndLoadTest(test_base.TestSavedModelBase):
30+
31+
def setUp(self):
32+
self._root_dir = 'saved_model_save_load'
33+
super(SavedModelSaveAndLoadTest, self).setUp()
34+
35+
def _save_model(self, model, saved_dir):
36+
saved_model.save(model, saved_dir)
37+
38+
def _load_and_run_model(self, distribution, saved_dir, predict_dataset,
39+
output_name):
40+
dist_predict_dataset = distribution.experimental_distribute_dataset(
41+
predict_dataset)
42+
per_replica_predict_data = next(iter(dist_predict_dataset))
43+
func = saved_model.load(saved_dir)
44+
result = distribution.experimental_run_v2(
45+
func.signatures[_DEFAULT_FUNCTION_KEY], per_replica_predict_data)
46+
return result[output_name]
47+
48+
@combinations.generate(test_base.simple_models_with_strategies())
49+
def test_save_no_strategy_restore_strategy(self, model_and_input,
50+
distribution):
51+
self.skipTest(('Saving/loading model with tf.distribute.Strategy is not ',
52+
'supported.'))
53+
self.run_test_save_no_strategy_restore_strategy(model_and_input,
54+
distribution)
55+
56+
@combinations.generate(test_base.simple_models_with_strategies())
57+
def test_save_strategy_restore_no_strategy(self, model_and_input,
58+
distribution):
59+
self.skipTest(('Saving/loading model with tf.distribute.Strategy is not ',
60+
'supported.'))
61+
self.run_test_save_strategy_restore_no_strategy(model_and_input,
62+
distribution)
63+
64+
@combinations.generate(test_base.simple_models_with_strategy_pairs())
65+
def test_save_strategy_restore_strategy(self, model_and_input,
66+
distribution_pair):
67+
self.skipTest(('Saving/loading model with tf.distribute.Strategy is not ',
68+
'supported.'))
69+
self.run_test_save_strategy_restore_strategy(model_and_input,
70+
distribution_pair)
71+
72+
73+
if __name__ == '__main__':
74+
test.main()

0 commit comments

Comments
 (0)