From e7bbfbf703b82bee83cb60265ae67f337cbc3f00 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Mon, 14 Nov 2016 15:04:44 -0200 Subject: [PATCH] Convert Fisher importer to new input system --- util/importers/fisher.py | 84 ++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 43 deletions(-) diff --git a/util/importers/fisher.py b/util/importers/fisher.py index 53f65b34..8ade40ac 100644 --- a/util/importers/fisher.py +++ b/util/importers/fisher.py @@ -1,19 +1,19 @@ import fnmatch -import numpy as np import os -import random import subprocess import wave +import tensorflow as tf +import unicodedata +import codecs from glob import glob from itertools import cycle from math import ceil from Queue import PriorityQueue -from Queue import Queue from threading import Thread from util.audio import audiofile_to_input_vector from util.gpu import get_available_gpus -from util.text import texts_to_sparse_tensor, validate_label +from util.text import text_to_char_array, ctc_label_dense_to_sparse, validate_label class DataSets(object): def __init__(self, train, dev, test): @@ -21,6 +21,11 @@ class DataSets(object): self._test = test self._train = train + def start_queue_threads(self, session): + self._dev.start_queue_threads(session) + self._test.start_queue_threads(session) + self._train.start_queue_threads(session) + @property def train(self): return self._train @@ -34,23 +39,28 @@ class DataSets(object): return self._test class DataSet(object): - def __init__(self, graph, txt_files, thread_count, batch_size, numcep, numcontext): - self._graph = graph + def __init__(self, txt_files, thread_count, batch_size, numcep, numcontext): self._numcep = numcep - self._batch_queue = Queue(2 * self._get_device_count()) + self._x = tf.placeholder(tf.float32, [None, numcep + (2 * numcep * numcontext)]) + self._x_length = tf.placeholder(tf.int32, []) + self._y = tf.placeholder(tf.int32, [None,]) + self._y_length = tf.placeholder(tf.int32, []) + self._example_queue = tf.PaddingFIFOQueue(shapes=[[None, numcep + (2 * numcep * numcontext)], [], [None,], []], + dtypes=[tf.float32, tf.int32, tf.int32, tf.int32], + capacity=2 * self._get_device_count() * batch_size) + self._enqueue_op = self._example_queue.enqueue([self._x, self._x_length, self._y, self._y_length]) self._txt_files = txt_files self._batch_size = batch_size self._numcontext = numcontext self._thread_count = thread_count self._files_circular_list = self._create_files_circular_list() - self._start_queue_threads() def _get_device_count(self): available_gpus = get_available_gpus() return max(len(available_gpus), 1) - def _start_queue_threads(self): - batch_threads = [Thread(target=self._populate_batch_queue) for i in xrange(self._thread_count)] + def start_queue_threads(self, session): + batch_threads = [Thread(target=self._populate_batch_queue, args=(session,)) for i in xrange(self._thread_count)] for batch_thread in batch_threads: batch_thread.daemon = True batch_thread.start() @@ -68,42 +78,30 @@ class DataSet(object): return cycle(files_list) def _populate_batch_queue(self): - with self._graph.as_default(): - n_steps = 0 - sources = [] - targets = [] - batch_index = 0 - for txt_file, wav_file in self._files_circular_list: - if batch_index == self._batch_size: - # Put batch on queue - target = texts_to_sparse_tensor(targets) - for index, next_source in enumerate(sources): - npad = ((0,(n_steps - next_source.shape[0])), (0,0)) - sources[index] = np.pad(next_source, pad_width=npad, mode='constant') - source = np.array(sources) - self._batch_queue.put((source, target)) - n_steps = 0 - sources = [] - targets = [] - batch_index = 0 - next_source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext) - if n_steps < next_source.shape[0]: - n_steps = next_source.shape[0] - sources.append(next_source) - with open(txt_file) as open_txt_file: - targets.append(open_txt_file.read()) - batch_index = batch_index + 1 + for txt_file, wav_file in self._files_circular_list: + source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext) + source_len = len(source) + with codecs.open(txt_file, encoding="utf-8") as open_txt_file: + target = unicodedata.normalize("NFKD", open_txt_file.read()).encode("ascii", "ignore") + target = text_to_char_array(target) + target_len = len(target) + session.run(self._enqueue_op, feed_dict={ + self._x: source, + self._x_length: source_len, + self._y: target, + self._y_length: target_len}) def next_batch(self): - source, target = self._batch_queue.get() - return (source, target, source.shape[1]) + source, source_lengths, target, target_lengths = self._example_queue.dequeue_many(self._batch_size) + sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._batch_size) + return source, source_lengths, sparse_labels @property def total_batches(self): # Note: If len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files return int(ceil(float(len(self._txt_files)) /float(self._batch_size))) -def read_data_sets(graph, data_dir, batch_size, numcep, numcontext, thread_count=8): +def read_data_sets(data_dir, batch_size, numcep, numcontext, thread_count=8): # Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19 # Conditionally convert Fisher sph data to wav @@ -123,13 +121,13 @@ def read_data_sets(graph, data_dir, batch_size, numcep, numcontext, thread_count _maybe_split_sets(data_dir, "fisher-2005-split-wav", "fisher-2005-split-wav-sets") # Create train DataSet - train = _read_data_set(graph, data_dir, "fisher-200?-split-wav-sets/train", thread_count, batch_size, numcep, numcontext) + train = _read_data_set(data_dir, "fisher-200?-split-wav-sets/train", thread_count, batch_size, numcep, numcontext) # Create dev DataSet - dev = _read_data_set(graph, data_dir, "fisher-200?-split-wav-sets/dev", thread_count, batch_size, numcep, numcontext) + dev = _read_data_set(data_dir, "fisher-200?-split-wav-sets/dev", thread_count, batch_size, numcep, numcontext) # Create test DataSet - test = _read_data_set(graph, data_dir, "fisher-200?-split-wav-sets/test", thread_count, batch_size, numcep, numcontext) + test = _read_data_set(data_dir, "fisher-200?-split-wav-sets/test", thread_count, batch_size, numcep, numcontext) # Return DataSets return DataSets(train, dev, test) @@ -302,7 +300,7 @@ def _maybe_split_dataset(filelist, target_dir): new_wav_file = os.path.join(target_dir, os.path.basename(wav_file)) os.rename(wav_file, new_wav_file) -def _read_data_set(graph, work_dir, data_set, thread_count, batch_size, numcep, numcontext): +def _read_data_set(work_dir, data_set, thread_count, batch_size, numcep, numcontext): # Create data set dir dataset_dir = os.path.join(work_dir, data_set) @@ -310,4 +308,4 @@ def _read_data_set(graph, work_dir, data_set, thread_count, batch_size, numcep, txt_files = glob(os.path.join(dataset_dir, "*.txt")) # Return DataSet - return DataSet(graph, txt_files, thread_count, batch_size, numcep, numcontext) + return DataSet(txt_files, thread_count, batch_size, numcep, numcontext)