Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 0 additions & 69 deletions keras/src/models/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,75 +1219,6 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
)
self.assertListEqual(hist_keys, ref_keys)

@parameterized.named_parameters(
("tf_saved_model", "tf_saved_model"),
("onnx", "onnx"),
)
@pytest.mark.skipif(
backend.backend() not in ("tensorflow", "jax", "torch"),
reason=(
"Currently, `Model.export` only supports the tensorflow, jax and "
"torch backends."
),
)
@pytest.mark.skipif(
testing.jax_uses_gpu(), reason="Leads to core dumps on CI"
)
def test_export(self, export_format):
if export_format == "tf_saved_model" and testing.torch_uses_gpu():
self.skipTest("Leads to core dumps on CI")

temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = _get_model()
x1 = np.random.rand(1, 3).astype("float32")
x2 = np.random.rand(1, 3).astype("float32")
ref_output = model([x1, x2])

model.export(temp_filepath, format=export_format)

if export_format == "tf_saved_model":
import tensorflow as tf

revived_model = tf.saved_model.load(temp_filepath)
self.assertAllClose(ref_output, revived_model.serve([x1, x2]))

# Test with a different batch size
if backend.backend() == "torch":
# TODO: Dynamic shape is not supported yet in the torch backend
return
revived_model.serve(
[
np.concatenate([x1, x1], axis=0),
np.concatenate([x2, x2], axis=0),
]
)
elif export_format == "onnx":
import onnxruntime

ort_session = onnxruntime.InferenceSession(temp_filepath)
ort_inputs = {
k.name: v for k, v in zip(ort_session.get_inputs(), [x1, x2])
}
self.assertAllClose(
ref_output, ort_session.run(None, ort_inputs)[0]
)

# Test with a different batch size
if backend.backend() == "torch":
# TODO: Dynamic shape is not supported yet in the torch backend
return
ort_inputs = {
k.name: v
for k, v in zip(
ort_session.get_inputs(),
[
np.concatenate([x1, x1], axis=0),
np.concatenate([x2, x2], axis=0),
],
)
}
ort_session.run(None, ort_inputs)

def test_export_error(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = _get_model()
Expand Down
Loading