Convert Fisher importer to new input system

This commit is contained in:
Reuben Morais 2016-11-14 15:04:44 -02:00
parent a40df7251e
commit e7bbfbf703

View File

@ -1,19 +1,19 @@
import fnmatch
import numpy as np
import os
import random
import subprocess
import wave
import tensorflow as tf
import unicodedata
import codecs
from glob import glob
from itertools import cycle
from math import ceil
from Queue import PriorityQueue
from Queue import Queue
from threading import Thread
from util.audio import audiofile_to_input_vector
from util.gpu import get_available_gpus
from util.text import texts_to_sparse_tensor, validate_label
from util.text import text_to_char_array, ctc_label_dense_to_sparse, validate_label
class DataSets(object):
def __init__(self, train, dev, test):
@ -21,6 +21,11 @@ class DataSets(object):
self._test = test
self._train = train
def start_queue_threads(self, session):
self._dev.start_queue_threads(session)
self._test.start_queue_threads(session)
self._train.start_queue_threads(session)
@property
def train(self):
return self._train
@ -34,23 +39,28 @@ class DataSets(object):
return self._test
class DataSet(object):
def __init__(self, graph, txt_files, thread_count, batch_size, numcep, numcontext):
self._graph = graph
def __init__(self, txt_files, thread_count, batch_size, numcep, numcontext):
self._numcep = numcep
self._batch_queue = Queue(2 * self._get_device_count())
self._x = tf.placeholder(tf.float32, [None, numcep + (2 * numcep * numcontext)])
self._x_length = tf.placeholder(tf.int32, [])
self._y = tf.placeholder(tf.int32, [None,])
self._y_length = tf.placeholder(tf.int32, [])
self._example_queue = tf.PaddingFIFOQueue(shapes=[[None, numcep + (2 * numcep * numcontext)], [], [None,], []],
dtypes=[tf.float32, tf.int32, tf.int32, tf.int32],
capacity=2 * self._get_device_count() * batch_size)
self._enqueue_op = self._example_queue.enqueue([self._x, self._x_length, self._y, self._y_length])
self._txt_files = txt_files
self._batch_size = batch_size
self._numcontext = numcontext
self._thread_count = thread_count
self._files_circular_list = self._create_files_circular_list()
self._start_queue_threads()
def _get_device_count(self):
available_gpus = get_available_gpus()
return max(len(available_gpus), 1)
def _start_queue_threads(self):
batch_threads = [Thread(target=self._populate_batch_queue) for i in xrange(self._thread_count)]
def start_queue_threads(self, session):
batch_threads = [Thread(target=self._populate_batch_queue, args=(session,)) for i in xrange(self._thread_count)]
for batch_thread in batch_threads:
batch_thread.daemon = True
batch_thread.start()
@ -68,42 +78,30 @@ class DataSet(object):
return cycle(files_list)
def _populate_batch_queue(self):
with self._graph.as_default():
n_steps = 0
sources = []
targets = []
batch_index = 0
for txt_file, wav_file in self._files_circular_list:
if batch_index == self._batch_size:
# Put batch on queue
target = texts_to_sparse_tensor(targets)
for index, next_source in enumerate(sources):
npad = ((0,(n_steps - next_source.shape[0])), (0,0))
sources[index] = np.pad(next_source, pad_width=npad, mode='constant')
source = np.array(sources)
self._batch_queue.put((source, target))
n_steps = 0
sources = []
targets = []
batch_index = 0
next_source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext)
if n_steps < next_source.shape[0]:
n_steps = next_source.shape[0]
sources.append(next_source)
with open(txt_file) as open_txt_file:
targets.append(open_txt_file.read())
batch_index = batch_index + 1
for txt_file, wav_file in self._files_circular_list:
source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext)
source_len = len(source)
with codecs.open(txt_file, encoding="utf-8") as open_txt_file:
target = unicodedata.normalize("NFKD", open_txt_file.read()).encode("ascii", "ignore")
target = text_to_char_array(target)
target_len = len(target)
session.run(self._enqueue_op, feed_dict={
self._x: source,
self._x_length: source_len,
self._y: target,
self._y_length: target_len})
def next_batch(self):
source, target = self._batch_queue.get()
return (source, target, source.shape[1])
source, source_lengths, target, target_lengths = self._example_queue.dequeue_many(self._batch_size)
sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._batch_size)
return source, source_lengths, sparse_labels
@property
def total_batches(self):
# Note: If len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files
return int(ceil(float(len(self._txt_files)) /float(self._batch_size)))
def read_data_sets(graph, data_dir, batch_size, numcep, numcontext, thread_count=8):
def read_data_sets(data_dir, batch_size, numcep, numcontext, thread_count=8):
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
# Conditionally convert Fisher sph data to wav
@ -123,13 +121,13 @@ def read_data_sets(graph, data_dir, batch_size, numcep, numcontext, thread_count
_maybe_split_sets(data_dir, "fisher-2005-split-wav", "fisher-2005-split-wav-sets")
# Create train DataSet
train = _read_data_set(graph, data_dir, "fisher-200?-split-wav-sets/train", thread_count, batch_size, numcep, numcontext)
train = _read_data_set(data_dir, "fisher-200?-split-wav-sets/train", thread_count, batch_size, numcep, numcontext)
# Create dev DataSet
dev = _read_data_set(graph, data_dir, "fisher-200?-split-wav-sets/dev", thread_count, batch_size, numcep, numcontext)
dev = _read_data_set(data_dir, "fisher-200?-split-wav-sets/dev", thread_count, batch_size, numcep, numcontext)
# Create test DataSet
test = _read_data_set(graph, data_dir, "fisher-200?-split-wav-sets/test", thread_count, batch_size, numcep, numcontext)
test = _read_data_set(data_dir, "fisher-200?-split-wav-sets/test", thread_count, batch_size, numcep, numcontext)
# Return DataSets
return DataSets(train, dev, test)
@ -302,7 +300,7 @@ def _maybe_split_dataset(filelist, target_dir):
new_wav_file = os.path.join(target_dir, os.path.basename(wav_file))
os.rename(wav_file, new_wav_file)
def _read_data_set(graph, work_dir, data_set, thread_count, batch_size, numcep, numcontext):
def _read_data_set(work_dir, data_set, thread_count, batch_size, numcep, numcontext):
# Create data set dir
dataset_dir = os.path.join(work_dir, data_set)
@ -310,4 +308,4 @@ def _read_data_set(graph, work_dir, data_set, thread_count, batch_size, numcep,
txt_files = glob(os.path.join(dataset_dir, "*.txt"))
# Return DataSet
return DataSet(graph, txt_files, thread_count, batch_size, numcep, numcontext)
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext)