Adapt Fisher importer to other importer API changes

This commit is contained in:
Reuben Morais 2016-12-21 14:38:17 -02:00
parent 33c9521a6f
commit 3ef07ce7d0

View File

@ -13,7 +13,7 @@ from Queue import PriorityQueue
from threading import Thread
from util.audio import audiofile_to_input_vector
from util.gpu import get_available_gpus
from util.text import text_to_char_array, validate_label
from util.text import text_to_char_array, validate_label, ctc_label_dense_to_sparse
class DataSets(object):
def __init__(self, train, dev, test):
@ -77,7 +77,7 @@ class DataSet(object):
files_list.append((txt_file, wav_file))
return cycle(files_list)
def _populate_batch_queue(self):
def _populate_batch_queue(self, session):
for txt_file, wav_file in self._files_circular_list:
source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext)
source_len = len(source)
@ -85,22 +85,26 @@ class DataSet(object):
target = unicodedata.normalize("NFKD", open_txt_file.read()).encode("ascii", "ignore")
target = text_to_char_array(target)
target_len = len(target)
session.run(self._enqueue_op, feed_dict={
self._x: source,
self._x_length: source_len,
self._y: target,
self._y_length: target_len})
try:
session.run(self._enqueue_op, feed_dict={
self._x: source,
self._x_length: source_len,
self._y: target,
self._y_length: target_len})
except (RuntimeError, tf.errors.CancelledError):
return
def next_batch(self):
source, source_lengths, target, target_lengths = self._example_queue.dequeue_many(self._batch_size)
return source, source_lengths, target, target_lengths
sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._batch_size)
return source, source_lengths, sparse_labels
@property
def total_batches(self):
# Note: If len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files
return int(ceil(float(len(self._txt_files)) /float(self._batch_size)))
def read_data_sets(data_dir, batch_size, numcep, numcontext, thread_count=8):
def read_data_sets(data_dir, train_batch_size, dev_batch_size, test_batch_size, numcep, numcontext, thread_count=8, limit_dev=0, limit_test=0, limit_train=0):
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
# Conditionally convert Fisher sph data to wav
@ -120,13 +124,13 @@ def read_data_sets(data_dir, batch_size, numcep, numcontext, thread_count=8):
_maybe_split_sets(data_dir, "fisher-2005-split-wav", "fisher-2005-split-wav-sets")
# Create train DataSet
train = _read_data_set(data_dir, "fisher-200?-split-wav-sets/train", thread_count, batch_size, numcep, numcontext)
train = _read_data_set(data_dir, "fisher-200?-split-wav-sets/train", thread_count, train_batch_size, numcep, numcontext, limit=limit_train)
# Create dev DataSet
dev = _read_data_set(data_dir, "fisher-200?-split-wav-sets/dev", thread_count, batch_size, numcep, numcontext)
dev = _read_data_set(data_dir, "fisher-200?-split-wav-sets/dev", thread_count, dev_batch_size, numcep, numcontext, limit=limit_dev)
# Create test DataSet
test = _read_data_set(data_dir, "fisher-200?-split-wav-sets/test", thread_count, batch_size, numcep, numcontext)
test = _read_data_set(data_dir, "fisher-200?-split-wav-sets/test", thread_count, test_batch_size, numcep, numcontext, limit=limit_test)
# Return DataSets
return DataSets(train, dev, test)
@ -299,12 +303,14 @@ def _maybe_split_dataset(filelist, target_dir):
new_wav_file = os.path.join(target_dir, os.path.basename(wav_file))
os.rename(wav_file, new_wav_file)
def _read_data_set(work_dir, data_set, thread_count, batch_size, numcep, numcontext):
def _read_data_set(work_dir, data_set, thread_count, batch_size, numcep, numcontext, limit=0):
# Create data set dir
dataset_dir = os.path.join(work_dir, data_set)
# Obtain list of txt files
txt_files = glob(os.path.join(dataset_dir, "*.txt"))
if limit > 0:
txt_files = txt_files[:limit]
# Return DataSet
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext)