Skip to content

Commit 7ffd513

Browse files
committed
minor refactor to optimize_graph
1 parent 567cf7a commit 7ffd513

File tree

2 files changed

+9
-26
lines changed

2 files changed

+9
-26
lines changed

00_Miscellaneous/model_optimisation/optimize_graph.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def freeze_model(saved_model_dir, output_node_names, output_filename):
314314
print('****************************************')
315315

316316

317-
def convert_graph_def_to_saved_model(export_dir, graph_filepath):
317+
def convert_graph_def_to_saved_model(export_dir, graph_filepath, output_key, output_node_name):
318318
if tf.gfile.Exists(export_dir):
319319
tf.gfile.DeleteRecursively(export_dir)
320320
graph_def = get_graph_def_from_file(graph_filepath)
@@ -327,8 +327,8 @@ def convert_graph_def_to_saved_model(export_dir, graph_filepath):
327327
node.name: session.graph.get_tensor_by_name(
328328
'{}:0'.format(node.name))
329329
for node in graph_def.node if node.op=='Placeholder'},
330-
outputs={'class_ids': session.graph.get_tensor_by_name(
331-
'head/predictions/class_ids:0')}
330+
outputs={output_key: session.graph.get_tensor_by_name(
331+
output_node_name)}
332332
)
333333
print('****************************************')
334334
print('Optimized graph converted to SavedModel!')
@@ -412,7 +412,8 @@ def main(args):
412412

413413
# convert to saved model and output metagraph again
414414
optimized_export_dir = os.path.join(export_dir, 'optimized')
415-
convert_graph_def_to_saved_model(optimized_export_dir, optimized_filepath)
415+
convert_graph_def_to_saved_model(optimized_export_dir, optimized_filepath, 'class_ids',
416+
'head/predictions/class_ids:0')
416417
get_size(optimized_export_dir, 'saved_model.pb')
417418
get_metagraph(optimized_export_dir)
418419

00_Miscellaneous/model_optimisation/optimize_graph_keras.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929

3030
from inference_test import inference_test, load_mnist_keras
3131
from optimize_graph import (run_experiment, get_graph_def_from_saved_model,
32-
describe_graph, convert_graph_def_to_saved_model, get_size, get_metagraph,
33-
get_graph_def_from_file, freeze_model, optimize_graph, TRANSFORMS)
32+
describe_graph, get_size, get_metagraph, get_graph_def_from_file,
33+
convert_graph_def_to_saved_model, freeze_model, optimize_graph, TRANSFORMS)
3434

3535
NUM_CLASSES = 10
3636
MODELS_LOCATION = 'models/mnist'
@@ -134,25 +134,6 @@ def make_serving_input_receiver_fn():
134134
return export_dir
135135

136136

137-
def convert_graph_def_to_saved_model(export_dir, graph_filepath, output_key, output_node_name):
138-
if tf.gfile.Exists(export_dir):
139-
tf.gfile.DeleteRecursively(export_dir)
140-
graph_def = get_graph_def_from_file(graph_filepath)
141-
with tf.Session(graph=tf.Graph()) as session:
142-
tf.import_graph_def(graph_def, name='')
143-
tf.saved_model.simple_save(
144-
session,
145-
export_dir,
146-
inputs={
147-
node.name: session.graph.get_tensor_by_name(
148-
'{}:0'.format(node.name))
149-
for node in graph_def.node if node.op=='Placeholder'},
150-
outputs={output_key: session.graph.get_tensor_by_name(
151-
output_node_name)}
152-
)
153-
print('Optimized graph converted to SavedModel!')
154-
155-
156137
def setup_model():
157138
train_data, train_labels, eval_data, eval_labels = load_mnist_keras()
158139
export_dir = train_and_export_model(train_data, train_labels)
@@ -216,7 +197,8 @@ def main(args):
216197

217198
# convert to saved model and output metagraph again
218199
optimized_export_dir = os.path.join(export_dir, 'optimized')
219-
convert_graph_def_to_saved_model(optimized_export_dir, optimized_filepath, 'softmax', 'softmax/Softmax:0')
200+
convert_graph_def_to_saved_model(optimized_export_dir, optimized_filepath,
201+
'softmax', 'softmax/Softmax:0')
220202
get_size(optimized_export_dir, 'saved_model.pb')
221203
get_metagraph(optimized_export_dir)
222204

0 commit comments

Comments
 (0)