From 6178c31a203af08101494d5affde822959aa612a Mon Sep 17 00:00:00 2001 From: Chris Lord Date: Tue, 8 Nov 2016 11:45:28 +0100 Subject: [PATCH] Write a Tensorflow Serving client --- client/BUILD | 15 ++++++++ client/README.md | 42 +++++++++++++++++++++ client/deepspeech_client.py | 75 +++++++++++++++++++++++++++++++++++++ util/text.py | 6 +++ 4 files changed, 138 insertions(+) create mode 100644 client/BUILD create mode 100644 client/README.md create mode 100644 client/deepspeech_client.py diff --git a/client/BUILD b/client/BUILD new file mode 100644 index 00000000..9df66ff4 --- /dev/null +++ b/client/BUILD @@ -0,0 +1,15 @@ +# Description: Deepspeech Serving Client. + +load("//tensorflow_serving:serving.bzl", "serving_proto_library") + +py_binary( + name = "deepspeech_client", + srcs = [ + "deepspeech_client.py", + ], + deps = [ + "//tensorflow_serving/apis:predict_proto_py_pb2", + "//tensorflow_serving/apis:prediction_service_proto_py_pb2", + "@org_tensorflow//tensorflow:tensorflow_py", + ], +) diff --git a/client/README.md b/client/README.md new file mode 100644 index 00000000..10626e52 --- /dev/null +++ b/client/README.md @@ -0,0 +1,42 @@ +# DeepSpeech client + +A client for running queries on an exported DeepSpeech model. + +## Requirements + +* [Tensorflow Serving](https://tensorflow.github.io/serving/setup) + +## Building + +Create a symbolic link in the Tensorflow Serving checkout to the deepspeech client directory. + +``` +cd serving +ln -s ../DeepSpeech/deepspeech_client ./ +``` + +If you haven't already, you'll need to build the Tensorflow Server. + +``` +bazel build -c opt //tensorflow_serving/model_servers:tensorflow_model_server +``` + +Then you can build the DeepSpeech client. + +``` +bazel build -c opt //deepspeech_client +``` + +## Running + +Start a server running an exported DeepSpeech model. + +``` +bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=deepspeech --model_base_path=/path/to/deepspeech/export +``` + +Now run the client. + +``` +bazel-bin/deepspeech_client/deepspeech_client --server=localhost:9000 --file=/path/to/audio.wav +``` diff --git a/client/deepspeech_client.py b/client/deepspeech_client.py new file mode 100644 index 00000000..58685fcb --- /dev/null +++ b/client/deepspeech_client.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python2.7 + +"""A client that talks to tensorflow_model_server loaded with deepspeech model. + +The client queries the service with the given audio and prints a ranked list +of decoded outputs to the standard output, one per line. + +Typical usage example: + + deepspeech_client.py --server=localhost:9000 --file audio.wav +""" + +import os +import sys +import threading +from grpc.beta import implementations +import numpy as np +import tensorflow as tf + +from tensorflow_serving.apis import predict_pb2 +from tensorflow_serving.apis import prediction_service_pb2 + +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')) +from util.text import ndarray_to_text +from util.audio import audiofile_to_input_vector + +tf.app.flags.DEFINE_string('server', '', 'PredictionService host:port') +tf.app.flags.DEFINE_string('file', '', 'Wave audio file') +# These need to match the constants used when training the deepspeech model +tf.app.flags.DEFINE_string('n_input', 26, 'Number of MFCC features') +tf.app.flags.DEFINE_string('n_context', 9, 'Number of frames of context') +FLAGS = tf.app.flags.FLAGS + +def _create_rpc_callback(event): + def _callback(result_future): + exception = result_future.exception() + if exception: + print exception + else: + results = tf.contrib.util.make_ndarray(result_future.result().outputs['outputs']) + for result in results[0]: + print ndarray_to_text(result) + event.set() + return _callback + +def do_inference(hostport, audio): + host, port = hostport.split(':') + channel = implementations.insecure_channel(host, int(port)) + stub = prediction_service_pb2.beta_create_PredictionService_stub(channel) + + request = predict_pb2.PredictRequest() + request.model_spec.name = 'deepspeech' + request.inputs['input'].CopyFrom(tf.contrib.util.make_tensor_proto(audio)) + + event = threading.Event() + result_future = stub.Predict.future(request, 5.0) # 5 seconds + result_future.add_done_callback(_create_rpc_callback(event)) + if event.is_set() != True: + event.wait() + +def main(_): + if not FLAGS.server: + print 'please specify server host:port' + return + if not FLAGS.file: + print 'pleace specify an audio file' + return + + audio_waves = audiofile_to_input_vector( + FLAGS.file, FLAGS.n_input, FLAGS.n_context) + audio = np.array([ audio_waves ]) + do_inference(FLAGS.server, audio) + +if __name__ == '__main__': + tf.app.run() diff --git a/util/text.py b/util/text.py index 36d10c2b..6ca26e98 100644 --- a/util/text.py +++ b/util/text.py @@ -67,6 +67,12 @@ def sparse_tuple_to_texts(tuple): # List of strings return results +def ndarray_to_text(value): + results = '' + for i in range(len(value)): + results += chr(value[i] + FIRST_INDEX) + return results.replace('`', ' ') + def wer(original, result): """ The WER is defined as the editing/Levenshtein distance on word level