mirror of
https://github.com/mozilla/DeepSpeech.git
synced 2025-10-26 11:19:39 +00:00
Adapt Fisher importer to other importer API changes
This commit is contained in:
parent
33c9521a6f
commit
3ef07ce7d0
@ -13,7 +13,7 @@ from Queue import PriorityQueue
|
||||
from threading import Thread
|
||||
from util.audio import audiofile_to_input_vector
|
||||
from util.gpu import get_available_gpus
|
||||
from util.text import text_to_char_array, validate_label
|
||||
from util.text import text_to_char_array, validate_label, ctc_label_dense_to_sparse
|
||||
|
||||
class DataSets(object):
|
||||
def __init__(self, train, dev, test):
|
||||
@ -77,7 +77,7 @@ class DataSet(object):
|
||||
files_list.append((txt_file, wav_file))
|
||||
return cycle(files_list)
|
||||
|
||||
def _populate_batch_queue(self):
|
||||
def _populate_batch_queue(self, session):
|
||||
for txt_file, wav_file in self._files_circular_list:
|
||||
source = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext)
|
||||
source_len = len(source)
|
||||
@ -85,22 +85,26 @@ class DataSet(object):
|
||||
target = unicodedata.normalize("NFKD", open_txt_file.read()).encode("ascii", "ignore")
|
||||
target = text_to_char_array(target)
|
||||
target_len = len(target)
|
||||
session.run(self._enqueue_op, feed_dict={
|
||||
self._x: source,
|
||||
self._x_length: source_len,
|
||||
self._y: target,
|
||||
self._y_length: target_len})
|
||||
try:
|
||||
session.run(self._enqueue_op, feed_dict={
|
||||
self._x: source,
|
||||
self._x_length: source_len,
|
||||
self._y: target,
|
||||
self._y_length: target_len})
|
||||
except (RuntimeError, tf.errors.CancelledError):
|
||||
return
|
||||
|
||||
def next_batch(self):
|
||||
source, source_lengths, target, target_lengths = self._example_queue.dequeue_many(self._batch_size)
|
||||
return source, source_lengths, target, target_lengths
|
||||
sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._batch_size)
|
||||
return source, source_lengths, sparse_labels
|
||||
|
||||
@property
|
||||
def total_batches(self):
|
||||
# Note: If len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files
|
||||
return int(ceil(float(len(self._txt_files)) /float(self._batch_size)))
|
||||
|
||||
def read_data_sets(data_dir, batch_size, numcep, numcontext, thread_count=8):
|
||||
def read_data_sets(data_dir, train_batch_size, dev_batch_size, test_batch_size, numcep, numcontext, thread_count=8, limit_dev=0, limit_test=0, limit_train=0):
|
||||
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
|
||||
|
||||
# Conditionally convert Fisher sph data to wav
|
||||
@ -120,13 +124,13 @@ def read_data_sets(data_dir, batch_size, numcep, numcontext, thread_count=8):
|
||||
_maybe_split_sets(data_dir, "fisher-2005-split-wav", "fisher-2005-split-wav-sets")
|
||||
|
||||
# Create train DataSet
|
||||
train = _read_data_set(data_dir, "fisher-200?-split-wav-sets/train", thread_count, batch_size, numcep, numcontext)
|
||||
train = _read_data_set(data_dir, "fisher-200?-split-wav-sets/train", thread_count, train_batch_size, numcep, numcontext, limit=limit_train)
|
||||
|
||||
# Create dev DataSet
|
||||
dev = _read_data_set(data_dir, "fisher-200?-split-wav-sets/dev", thread_count, batch_size, numcep, numcontext)
|
||||
dev = _read_data_set(data_dir, "fisher-200?-split-wav-sets/dev", thread_count, dev_batch_size, numcep, numcontext, limit=limit_dev)
|
||||
|
||||
# Create test DataSet
|
||||
test = _read_data_set(data_dir, "fisher-200?-split-wav-sets/test", thread_count, batch_size, numcep, numcontext)
|
||||
test = _read_data_set(data_dir, "fisher-200?-split-wav-sets/test", thread_count, test_batch_size, numcep, numcontext, limit=limit_test)
|
||||
|
||||
# Return DataSets
|
||||
return DataSets(train, dev, test)
|
||||
@ -299,12 +303,14 @@ def _maybe_split_dataset(filelist, target_dir):
|
||||
new_wav_file = os.path.join(target_dir, os.path.basename(wav_file))
|
||||
os.rename(wav_file, new_wav_file)
|
||||
|
||||
def _read_data_set(work_dir, data_set, thread_count, batch_size, numcep, numcontext):
|
||||
def _read_data_set(work_dir, data_set, thread_count, batch_size, numcep, numcontext, limit=0):
|
||||
# Create data set dir
|
||||
dataset_dir = os.path.join(work_dir, data_set)
|
||||
|
||||
# Obtain list of txt files
|
||||
txt_files = glob(os.path.join(dataset_dir, "*.txt"))
|
||||
if limit > 0:
|
||||
txt_files = txt_files[:limit]
|
||||
|
||||
# Return DataSet
|
||||
return DataSet(txt_files, thread_count, batch_size, numcep, numcontext)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user