Skip to content
Merged
Prev Previous commit
Next Next commit
Cleanup1
  • Loading branch information
glenn-jocher authored Feb 18, 2022
commit 14d0d78e01ce3cf7fd25d6f0a85d142e3740eb08
30 changes: 16 additions & 14 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
import pandas as pd
import torch
import torch.nn as nn
from packaging.version import parse as parse_version
from torch.utils.mobile_optimizer import optimize_for_mobile

FILE = Path(__file__).resolve()
Expand Down Expand Up @@ -248,7 +247,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F

def export_saved_model(model, im, file, dynamic,
tf_nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.45,
conf_thres=0.25, prefix=colorstr('TensorFlow SavedModel:')):
conf_thres=0.25, use_keras=False, prefix=colorstr('TensorFlow SavedModel:')):
# YOLOv5 TensorFlow SavedModel export
try:
import tensorflow as tf
Expand All @@ -269,18 +268,21 @@ def export_saved_model(model, im, file, dynamic,
keras_model = keras.Model(inputs=inputs, outputs=outputs)
keras_model.trainable = False
keras_model.summary()
m = tf.function(lambda x: keras_model(x)) # full model
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
frozen_func = convert_variables_to_constants_v2(m)
tfm = tf.Module()
tfm.__call__ = tf.function(lambda x: frozen_func(x),
[tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)])
tfm.__call__(im)
tf.saved_model.save(
tfm,
f,
options=tf.saved_model.SaveOptions() if parse_version(tf.__version__) < parse_version("2.6") else
tf.saved_model.SaveOptions(experimental_custom_gradients=False))
if use_keras:
keras_model.save(f, save_format='tf')
else:
m = tf.function(lambda x: keras_model(x)) # full model
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
frozen_func = convert_variables_to_constants_v2(m)
tfm = tf.Module()
tfm.__call__ = tf.function(lambda x: frozen_func(x),
[tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)])
tfm.__call__(im)
tf.saved_model.save(
tfm,
f,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if
check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
return keras_model, f
except Exception as e:
Expand Down