diff --git a/util/importers/fisher.py b/util/importers/fisher.py index b143d022..120e5e65 100644 --- a/util/importers/fisher.py +++ b/util/importers/fisher.py @@ -13,7 +13,7 @@ from Queue import PriorityQueue from threading import Thread from util.audio import audiofile_to_input_vector from util.gpu import get_available_gpus -from util.text import text_to_char_array, validate_label +from util.text import text_to_char_array, validate_label, ctc_label_dense_to_sparse class DataSets(object): def __init__(self, train, dev, test): @@ -77,7 +77,7 @@ class DataSet(object): files_list.append((txt_file, wav_file)) return cycle(files_list) - def _populate_batch_queue(self): + def _populate_batch_queue(self, session): for txt_file, wav_file in self._files_circular_list: source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext) source_len = len(source) @@ -85,22 +85,26 @@ class DataSet(object): target = unicodedata.normalize("NFKD", open_txt_file.read()).encode("ascii", "ignore") target = text_to_char_array(target) target_len = len(target) - session.run(self._enqueue_op, feed_dict={ - self._x: source, - self._x_length: source_len, - self._y: target, - self._y_length: target_len}) + try: + session.run(self._enqueue_op, feed_dict={ + self._x: source, + self._x_length: source_len, + self._y: target, + self._y_length: target_len}) + except (RuntimeError, tf.errors.CancelledError): + return def next_batch(self): source, source_lengths, target, target_lengths = self._example_queue.dequeue_many(self._batch_size) - return source, source_lengths, target, target_lengths + sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._batch_size) + return source, source_lengths, sparse_labels @property def total_batches(self): # Note: If len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files return int(ceil(float(len(self._txt_files)) /float(self._batch_size))) -def read_data_sets(data_dir, batch_size, numcep, numcontext, thread_count=8): +def read_data_sets(data_dir, train_batch_size, dev_batch_size, test_batch_size, numcep, numcontext, thread_count=8, limit_dev=0, limit_test=0, limit_train=0): # Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19 # Conditionally convert Fisher sph data to wav @@ -120,13 +124,13 @@ def read_data_sets(data_dir, batch_size, numcep, numcontext, thread_count=8): _maybe_split_sets(data_dir, "fisher-2005-split-wav", "fisher-2005-split-wav-sets") # Create train DataSet - train = _read_data_set(data_dir, "fisher-200?-split-wav-sets/train", thread_count, batch_size, numcep, numcontext) + train = _read_data_set(data_dir, "fisher-200?-split-wav-sets/train", thread_count, train_batch_size, numcep, numcontext, limit=limit_train) # Create dev DataSet - dev = _read_data_set(data_dir, "fisher-200?-split-wav-sets/dev", thread_count, batch_size, numcep, numcontext) + dev = _read_data_set(data_dir, "fisher-200?-split-wav-sets/dev", thread_count, dev_batch_size, numcep, numcontext, limit=limit_dev) # Create test DataSet - test = _read_data_set(data_dir, "fisher-200?-split-wav-sets/test", thread_count, batch_size, numcep, numcontext) + test = _read_data_set(data_dir, "fisher-200?-split-wav-sets/test", thread_count, test_batch_size, numcep, numcontext, limit=limit_test) # Return DataSets return DataSets(train, dev, test) @@ -299,12 +303,14 @@ def _maybe_split_dataset(filelist, target_dir): new_wav_file = os.path.join(target_dir, os.path.basename(wav_file)) os.rename(wav_file, new_wav_file) -def _read_data_set(work_dir, data_set, thread_count, batch_size, numcep, numcontext): +def _read_data_set(work_dir, data_set, thread_count, batch_size, numcep, numcontext, limit=0): # Create data set dir dataset_dir = os.path.join(work_dir, data_set) # Obtain list of txt files txt_files = glob(os.path.join(dataset_dir, "*.txt")) + if limit > 0: + txt_files = txt_files[:limit] # Return DataSet return DataSet(txt_files, thread_count, batch_size, numcep, numcontext)