DeepSpeech/util/preprocess.py
2018-09-14 11:20:28 -03:00

102 lines
4.0 KiB
Python

import numpy as np
import os
import pandas
import tables
from functools import partial
from multiprocessing.dummy import Pool
from util.audio import audiofile_to_input_vector
from util.text import text_to_char_array
def pmap(fun, iterable, threads=8):
    """Apply ``fun`` to every item of ``iterable`` using a pool of threads.

    ``Pool`` here is ``multiprocessing.dummy.Pool``, i.e. a thread pool, so
    ``fun`` runs in threads of the current process.

    Args:
        fun: Callable taking a single item.
        iterable: Items to process.
        threads: Number of worker threads in the pool.

    Returns:
        A list with the results of ``fun``, in the order of ``iterable``.

    Raises:
        Whatever ``fun`` raises; the exception is propagated by ``pool.map``.
    """
    pool = Pool(threads)
    try:
        return pool.map(fun, iterable)
    finally:
        # Shut the workers down on every path. Without this, an exception
        # raised by ``fun`` would leave the pool (and its threads) open.
        pool.close()
        pool.join()
def process_single_file(row, numcep, numcontext, alphabet):
    """Turn one CSV row into a (features, length, transcript, length) tuple.

    Args:
        row: An ``(index, Series)`` pair; the Series must expose
            ``wav_filename`` and ``transcript`` attributes.
        numcep: Forwarded to ``audiofile_to_input_vector``.
        numcontext: Forwarded to ``audiofile_to_input_vector`` and used in
            the length check below.
        alphabet: Forwarded to ``text_to_char_array``.

    Returns:
        ``(features, len(features), transcript, len(transcript))``.

    Raises:
        ValueError: If the transcript is longer than
            ``2*numcontext + len(features)``.
    """
    # The index half of the pair is not needed.
    _index, sample = row
    wav_path = sample.wav_filename

    features = audiofile_to_input_vector(wav_path, numcep, numcontext)
    transcript = text_to_char_array(sample.transcript, alphabet)

    n_features = len(features)
    n_labels = len(transcript)

    # Reject samples whose transcript has more labels than the audio allows.
    if n_labels > 2*numcontext + n_features:
        raise ValueError('Error: Audio file {} is too short for transcription.'.format(wav_path))

    return features, n_features, transcript, n_labels
# load samples from CSV, compute features, optionally cache results on disk
def _load_cached_samples(hdf5_cache_path):
    """Read cached samples from an HDF5 file as a list of per-sample tuples."""
    with tables.open_file(hdf5_cache_path, 'r') as cache:
        features = cache.root.features[:]
        features_len = cache.root.features_len[:]
        transcript = cache.root.transcript[:]
        transcript_len = cache.root.transcript_len[:]

    # features are stored flattened, so reshape into
    # [n_steps, (n_input + 2*n_context*n_input)]
    features = [np.reshape(flat, [n_steps, -1])
                for flat, n_steps in zip(features, features_len)]
    return list(zip(features, features_len, transcript, transcript_len))


def _read_sample_index(csv_files):
    """Load all CSVs into one DataFrame, resolving relative wav paths."""
    frames = []
    for csv in csv_files:
        frame = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
        csv_dir = os.path.dirname(os.path.abspath(csv))

        def resolve(path, base=csv_dir):
            # Empty entries and absolute paths are left untouched; relative
            # paths are taken relative to the CSV's directory. os.path
            # handles the platform's own notion of "absolute".
            if not path or os.path.isabs(path):
                return path
            return os.path.join(base, path)

        frame['wav_filename'] = frame['wav_filename'].map(resolve)
        frames.append(frame)

    if not frames:
        raise ValueError('No CSV files were given to preprocess.')

    # DataFrame.append was removed in pandas 2.0; concat keeps row order
    # and the original per-file indices.
    return pandas.concat(frames)


def _save_samples_to_cache(hdf5_cache_path, out_data):
    """Write processed sample tuples to an HDF5 file at ``hdf5_cache_path``."""
    # list of tuples -> tuple of lists
    features, features_len, transcript, transcript_len = zip(*out_data)

    with tables.open_file(hdf5_cache_path, 'w') as cache:
        features_dset = cache.create_vlarray(cache.root,
                                             'features',
                                             tables.Float32Atom(),
                                             filters=tables.Filters(complevel=1))
        # VLArray atoms need to be 1D, so flatten feature array
        for sample_features in features:
            features_dset.append(np.reshape(sample_features, -1))

        cache.create_array(cache.root, 'features_len', features_len)

        transcript_dset = cache.create_vlarray(cache.root,
                                               'transcript',
                                               tables.Int32Atom(),
                                               filters=tables.Filters(complevel=1))
        for sample_transcript in transcript:
            transcript_dset.append(sample_transcript)

        cache.create_array(cache.root, 'transcript_len', transcript_len)


def preprocess(csv_files, batch_size, numcep, numcontext, alphabet, hdf5_cache_path=None):
    """Load samples from CSV, compute features, optionally cache them on disk.

    If ``hdf5_cache_path`` is given and the file already exists, the samples
    are read from it and the CSV files are not touched. Otherwise every row
    of every CSV is run through ``process_single_file`` and, when
    ``hdf5_cache_path`` is given, the result is written there.

    Args:
        csv_files: Iterable of CSV paths. Each CSV must have ``wav_filename``
            and ``transcript`` columns; relative ``wav_filename`` entries are
            resolved against the directory of the CSV that lists them.
        batch_size: Not used by this function.
        numcep: Forwarded to ``process_single_file``.
        numcontext: Forwarded to ``process_single_file``.
        alphabet: Forwarded to ``process_single_file``.
        hdf5_cache_path: Optional path of the HDF5 cache file.

    Returns:
        A ``pandas.DataFrame`` with the columns ``features``,
        ``features_len``, ``transcript`` and ``transcript_len``.

    Raises:
        ValueError: If ``csv_files`` is empty and no cache was loaded, or if
            ``process_single_file`` rejects a sample.
    """
    COLUMNS = ('features', 'features_len', 'transcript', 'transcript_len')

    print('Preprocessing', csv_files)

    if hdf5_cache_path and os.path.exists(hdf5_cache_path):
        in_data = _load_cached_samples(hdf5_cache_path)
        print('Loaded from cache at', hdf5_cache_path)
        return pandas.DataFrame(data=in_data, columns=COLUMNS)

    source_data = _read_sample_index(csv_files)

    step_fn = partial(process_single_file,
                      numcep=numcep,
                      numcontext=numcontext,
                      alphabet=alphabet)
    out_data = pmap(step_fn, source_data.iterrows())

    if hdf5_cache_path:
        print('Saving to', hdf5_cache_path)
        _save_samples_to_cache(hdf5_cache_path, out_data)

    print('Preprocessing done')
    return pandas.DataFrame(data=out_data, columns=COLUMNS)