diff --git a/DeepSpeech.py b/DeepSpeech.py index 37cf222a..1883724d 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -18,7 +18,7 @@ from datetime import datetime from ds_ctcdecoder import ctc_beam_search_decoder, Scorer from evaluate import evaluate from six.moves import zip, range -from tensorflow.python.tools import freeze_graph +from tensorflow.python.tools import freeze_graph, strip_unused_lib from util.config import Config, initialize_globals from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features from util.flags import create_flags, FLAGS @@ -711,7 +711,7 @@ def export(): os.makedirs(FLAGS.export_dir) def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None): - return freeze_graph.freeze_graph_with_def_protos( + frozen = freeze_graph.freeze_graph_with_def_protos( input_graph_def=tf.get_default_graph().as_graph_def(), input_saver_def=saver.as_saver_def(), input_checkpoint=checkpoint_path, @@ -723,6 +723,13 @@ def export(): variable_names_blacklist=variables_blacklist, initializer_nodes='') + input_node_names = [] + return strip_unused_lib.strip_unused( + input_graph_def=frozen, + input_node_names=input_node_names, + output_node_names=output_node_names.split(','), + placeholder_type_enum=tf.float32.as_datatype_enum) + if not FLAGS.export_tflite: frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h') frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())