mirror of
https://github.com/mozilla/DeepSpeech.git
synced 2025-10-26 11:19:39 +00:00
Merge pull request #127 from Cwiiis/serving-client
Write a Tensorflow Serving client, fixes #21
This commit is contained in:
commit
6449fc45bd
15
client/BUILD
Normal file
15
client/BUILD
Normal file
@ -0,0 +1,15 @@
|
||||
# Description: Deepspeech Serving Client.
|
||||
|
||||
load("//tensorflow_serving:serving.bzl", "serving_proto_library")
|
||||
|
||||
# Executable gRPC client for a DeepSpeech model hosted by TensorFlow Serving.
# Depends on the Serving predict API protos and on TensorFlow itself.
py_binary(
    name = "deepspeech_client",
    srcs = ["deepspeech_client.py"],
    deps = [
        "//tensorflow_serving/apis:predict_proto_py_pb2",
        "//tensorflow_serving/apis:prediction_service_proto_py_pb2",
        "@org_tensorflow//tensorflow:tensorflow_py",
    ],
)
|
||||
42
client/README.md
Normal file
42
client/README.md
Normal file
@ -0,0 +1,42 @@
|
||||
# DeepSpeech client
|
||||
|
||||
A client for running queries on an exported DeepSpeech model.
|
||||
|
||||
## Requirements
|
||||
|
||||
* [Tensorflow Serving](https://tensorflow.github.io/serving/setup)
|
||||
|
||||
## Building
|
||||
|
||||
Create a symbolic link in the Tensorflow Serving checkout to the deepspeech client directory.
|
||||
|
||||
```
|
||||
cd serving
|
||||
ln -s ../DeepSpeech/client ./deepspeech_client
|
||||
```
|
||||
|
||||
If you haven't already, you'll need to build the TensorFlow Serving model server.
|
||||
|
||||
```
|
||||
bazel build -c opt //tensorflow_serving/model_servers:tensorflow_model_server
|
||||
```
|
||||
|
||||
Then you can build the DeepSpeech client.
|
||||
|
||||
```
|
||||
bazel build -c opt //deepspeech_client
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
Start a server running an exported DeepSpeech model.
|
||||
|
||||
```
|
||||
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=deepspeech --model_base_path=/path/to/deepspeech/export
|
||||
```
|
||||
|
||||
Now run the client.
|
||||
|
||||
```
|
||||
bazel-bin/deepspeech_client/deepspeech_client --server=localhost:9000 --file=/path/to/audio.wav
|
||||
```
|
||||
75
client/deepspeech_client.py
Normal file
75
client/deepspeech_client.py
Normal file
@ -0,0 +1,75 @@
|
||||
#!/usr/bin/env python2.7
|
||||
|
||||
"""A client that talks to tensorflow_model_server loaded with deepspeech model.
|
||||
|
||||
The client queries the service with the given audio and prints a ranked list
|
||||
of decoded outputs to the standard output, one per line.
|
||||
|
||||
Typical usage example:
|
||||
|
||||
deepspeech_client.py --server=localhost:9000 --file audio.wav
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from grpc.beta import implementations
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_serving.apis import predict_pb2
|
||||
from tensorflow_serving.apis import prediction_service_pb2
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..'))
|
||||
from util.text import ndarray_to_text
|
||||
from util.audio import audiofile_to_input_vector
|
||||
|
||||
# Command-line flags. 'server' and 'file' default to the empty string and are
# validated in main().
tf.app.flags.DEFINE_string('server', '', 'PredictionService host:port')
tf.app.flags.DEFINE_string('file', '', 'Wave audio file')
# These need to match the constants used when training the deepspeech model.
# They are integer quantities, so they are declared as integer flags: with a
# string flag, a value given on the command line would arrive as a str.
tf.app.flags.DEFINE_integer('n_input', 26, 'Number of MFCC features')
tf.app.flags.DEFINE_integer('n_context', 9, 'Number of frames of context')
FLAGS = tf.app.flags.FLAGS
|
||||
|
||||
def _create_rpc_callback(event):
    """Build the completion callback for an asynchronous Predict call.

    Args:
      event: Object with a set() method (do_inference passes a
        threading.Event). It is set once the callback has finished, whether
        the RPC succeeded, failed, or the results could not be decoded.

    Returns:
      A function taking the finished result future. If the future carries an
      exception it is printed; otherwise each entry of the first row of the
      'outputs' tensor is converted with ndarray_to_text and printed, one per
      line.
    """
    def _callback(result_future):
        try:
            exception = result_future.exception()
            if exception:
                print(exception)
                return
            outputs = result_future.result().outputs['outputs']
            results = tf.contrib.util.make_ndarray(outputs)
            for result in results[0]:
                print(ndarray_to_text(result))
        finally:
            # Always release the waiter. Without this, an error while decoding
            # or printing would leave the caller blocked on event.wait().
            event.set()
    return _callback
|
||||
|
||||
def do_inference(hostport, audio, model_name='deepspeech', timeout=5.0):
    """Send one Predict request to a PredictionService and wait for the reply.

    The reply is handled by the callback from _create_rpc_callback, which
    prints either the RPC error or the decoded outputs.

    Args:
      hostport: Server address in 'host:port' form. The port is taken from
        the text after the last ':' so that hosts containing colons still
        parse.
      audio: Value convertible by tf.contrib.util.make_tensor_proto; it is
        sent as the request input named 'input'.
      model_name: Name of the served model to query.
      timeout: RPC deadline in seconds.

    Raises:
      ValueError: If hostport contains no ':' or the port is not an integer.
    """
    host, port = hostport.rsplit(':', 1)
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    request.inputs['input'].CopyFrom(tf.contrib.util.make_tensor_proto(audio))

    event = threading.Event()
    result_future = stub.Predict.future(request, timeout)
    result_future.add_done_callback(_create_rpc_callback(event))
    # wait() returns immediately if the event is already set, so no prior
    # is_set() check is needed.
    event.wait()
|
||||
|
||||
def main(_):
    """Validate the flags, featurize the audio file and query the server.

    Prints a usage hint and returns early when --server or --file is missing.
    Otherwise the file is converted with audiofile_to_input_vector, wrapped
    in a one-element array, and passed to do_inference.

    Args:
      _: Unused positional argument supplied by tf.app.run().
    """
    if not FLAGS.server:
        print('please specify server host:port')
        return
    if not FLAGS.file:
        print('please specify an audio file')
        return

    # Coerce to int so that values supplied on the command line work even if
    # the flag parser hands them over as strings.
    audio_waves = audiofile_to_input_vector(
        FLAGS.file, int(FLAGS.n_input), int(FLAGS.n_context))
    # The extra enclosing list adds a leading axis of length one.
    audio = np.array([audio_waves])
    do_inference(FLAGS.server, audio)
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.app.run()
|
||||
@ -67,6 +67,12 @@ def sparse_tuple_to_texts(tuple):
|
||||
# List of strings
|
||||
return results
|
||||
|
||||
def ndarray_to_text(value):
    """Convert a sequence of label indices into a text string.

    Each element is offset by FIRST_INDEX and mapped to its character with
    chr(); any backtick in the result is then replaced by a space.

    Args:
      value: Iterable of integer label indices (for example a 1-D ndarray).

    Returns:
      The decoded string.
    """
    # Join once instead of growing the string with += on every element.
    results = ''.join(chr(index + FIRST_INDEX) for index in value)
    return results.replace('`', ' ')
|
||||
|
||||
def wer(original, result):
|
||||
"""
|
||||
The WER is defined as the editing/Levenshtein distance on word level
|
||||
|
||||
Loading…
Reference in New Issue
Block a user