mirror of
https://github.com/mozilla/DeepSpeech.git
synced 2025-10-26 11:19:39 +00:00
Convert Fisher importer to new input system
This commit is contained in:
parent
a40df7251e
commit
e7bbfbf703
@ -1,19 +1,19 @@
|
||||
import fnmatch
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import wave
|
||||
import tensorflow as tf
|
||||
import unicodedata
|
||||
import codecs
|
||||
|
||||
from glob import glob
|
||||
from itertools import cycle
|
||||
from math import ceil
|
||||
from Queue import PriorityQueue
|
||||
from Queue import Queue
|
||||
from threading import Thread
|
||||
from util.audio import audiofile_to_input_vector
|
||||
from util.gpu import get_available_gpus
|
||||
from util.text import texts_to_sparse_tensor, validate_label
|
||||
from util.text import text_to_char_array, ctc_label_dense_to_sparse, validate_label
|
||||
|
||||
class DataSets(object):
|
||||
def __init__(self, train, dev, test):
|
||||
@ -21,6 +21,11 @@ class DataSets(object):
|
||||
self._test = test
|
||||
self._train = train
|
||||
|
||||
def start_queue_threads(self, session):
|
||||
self._dev.start_queue_threads(session)
|
||||
self._test.start_queue_threads(session)
|
||||
self._train.start_queue_threads(session)
|
||||
|
||||
@property
|
||||
def train(self):
|
||||
return self._train
|
||||
@ -34,23 +39,28 @@ class DataSets(object):
|
||||
return self._test
|
||||
|
||||
class DataSet(object):
|
||||
def __init__(self, graph, txt_files, thread_count, batch_size, numcep, numcontext):
|
||||
self._graph = graph
|
||||
def __init__(self, txt_files, thread_count, batch_size, numcep, numcontext):
|
||||
self._numcep = numcep
|
||||
self._batch_queue = Queue(2 * self._get_device_count())
|
||||
self._x = tf.placeholder(tf.float32, [None, numcep + (2 * numcep * numcontext)])
|
||||
self._x_length = tf.placeholder(tf.int32, [])
|
||||
self._y = tf.placeholder(tf.int32, [None,])
|
||||
self._y_length = tf.placeholder(tf.int32, [])
|
||||
self._example_queue = tf.PaddingFIFOQueue(shapes=[[None, numcep + (2 * numcep * numcontext)], [], [None,], []],
|
||||
dtypes=[tf.float32, tf.int32, tf.int32, tf.int32],
|
||||
capacity=2 * self._get_device_count() * batch_size)
|
||||
self._enqueue_op = self._example_queue.enqueue([self._x, self._x_length, self._y, self._y_length])
|
||||
self._txt_files = txt_files
|
||||
self._batch_size = batch_size
|
||||
self._numcontext = numcontext
|
||||
self._thread_count = thread_count
|
||||
self._files_circular_list = self._create_files_circular_list()
|
||||
self._start_queue_threads()
|
||||
|
||||
def _get_device_count(self):
|
||||
available_gpus = get_available_gpus()
|
||||
return max(len(available_gpus), 1)
|
||||
|
||||
def _start_queue_threads(self):
|
||||
batch_threads = [Thread(target=self._populate_batch_queue) for i in xrange(self._thread_count)]
|
||||
def start_queue_threads(self, session):
|
||||
batch_threads = [Thread(target=self._populate_batch_queue, args=(session,)) for i in xrange(self._thread_count)]
|
||||
for batch_thread in batch_threads:
|
||||
batch_thread.daemon = True
|
||||
batch_thread.start()
|
||||
@ -68,42 +78,30 @@ class DataSet(object):
|
||||
return cycle(files_list)
|
||||
|
||||
def _populate_batch_queue(self):
|
||||
with self._graph.as_default():
|
||||
n_steps = 0
|
||||
sources = []
|
||||
targets = []
|
||||
batch_index = 0
|
||||
for txt_file, wav_file in self._files_circular_list:
|
||||
if batch_index == self._batch_size:
|
||||
# Put batch on queue
|
||||
target = texts_to_sparse_tensor(targets)
|
||||
for index, next_source in enumerate(sources):
|
||||
npad = ((0,(n_steps - next_source.shape[0])), (0,0))
|
||||
sources[index] = np.pad(next_source, pad_width=npad, mode='constant')
|
||||
source = np.array(sources)
|
||||
self._batch_queue.put((source, target))
|
||||
n_steps = 0
|
||||
sources = []
|
||||
targets = []
|
||||
batch_index = 0
|
||||
next_source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext)
|
||||
if n_steps < next_source.shape[0]:
|
||||
n_steps = next_source.shape[0]
|
||||
sources.append(next_source)
|
||||
with open(txt_file) as open_txt_file:
|
||||
targets.append(open_txt_file.read())
|
||||
batch_index = batch_index + 1
|
||||
for txt_file, wav_file in self._files_circular_list:
|
||||
source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext)
|
||||
source_len = len(source)
|
||||
with codecs.open(txt_file, encoding="utf-8") as open_txt_file:
|
||||
target = unicodedata.normalize("NFKD", open_txt_file.read()).encode("ascii", "ignore")
|
||||
target = text_to_char_array(target)
|
||||
target_len = len(target)
|
||||
session.run(self._enqueue_op, feed_dict={
|
||||
self._x: source,
|
||||
self._x_length: source_len,
|
||||
self._y: target,
|
||||
self._y_length: target_len})
|
||||
|
||||
def next_batch(self):
|
||||
source, target = self._batch_queue.get()
|
||||
return (source, target, source.shape[1])
|
||||
source, source_lengths, target, target_lengths = self._example_queue.dequeue_many(self._batch_size)
|
||||
sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._batch_size)
|
||||
return source, source_lengths, sparse_labels
|
||||
|
||||
@property
|
||||
def total_batches(self):
|
||||
# Note: If len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files
|
||||
return int(ceil(float(len(self._txt_files)) /float(self._batch_size)))
|
||||
|
||||
def read_data_sets(graph, data_dir, batch_size, numcep, numcontext, thread_count=8):
|
||||
def read_data_sets(data_dir, batch_size, numcep, numcontext, thread_count=8):
|
||||
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
|
||||
|
||||
# Conditionally convert Fisher sph data to wav
|
||||
@ -123,13 +121,13 @@ def read_data_sets(graph, data_dir, batch_size, numcep, numcontext, thread_count
|
||||
_maybe_split_sets(data_dir, "fisher-2005-split-wav", "fisher-2005-split-wav-sets")
|
||||
|
||||
# Create train DataSet
|
||||
train = _read_data_set(graph, data_dir, "fisher-200?-split-wav-sets/train", thread_count, batch_size, numcep, numcontext)
|
||||
train = _read_data_set(data_dir, "fisher-200?-split-wav-sets/train", thread_count, batch_size, numcep, numcontext)
|
||||
|
||||
# Create dev DataSet
|
||||
dev = _read_data_set(graph, data_dir, "fisher-200?-split-wav-sets/dev", thread_count, batch_size, numcep, numcontext)
|
||||
dev = _read_data_set(data_dir, "fisher-200?-split-wav-sets/dev", thread_count, batch_size, numcep, numcontext)
|
||||
|
||||
# Create test DataSet
|
||||
test = _read_data_set(graph, data_dir, "fisher-200?-split-wav-sets/test", thread_count, batch_size, numcep, numcontext)
|
||||
test = _read_data_set(data_dir, "fisher-200?-split-wav-sets/test", thread_count, batch_size, numcep, numcontext)
|
||||
|
||||
# Return DataSets
|
||||
return DataSets(train, dev, test)
|
||||
@ -302,7 +300,7 @@ def _maybe_split_dataset(filelist, target_dir):
|
||||
new_wav_file = os.path.join(target_dir, os.path.basename(wav_file))
|
||||
os.rename(wav_file, new_wav_file)
|
||||
|
||||
def _read_data_set(graph, work_dir, data_set, thread_count, batch_size, numcep, numcontext):
|
||||
def _read_data_set(work_dir, data_set, thread_count, batch_size, numcep, numcontext):
|
||||
# Create data set dir
|
||||
dataset_dir = os.path.join(work_dir, data_set)
|
||||
|
||||
@ -310,4 +308,4 @@ def _read_data_set(graph, work_dir, data_set, thread_count, batch_size, numcep,
|
||||
txt_files = glob(os.path.join(dataset_dir, "*.txt"))
|
||||
|
||||
# Return DataSet
|
||||
return DataSet(graph, txt_files, thread_count, batch_size, numcep, numcontext)
|
||||
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user