Switchboard importer

This commit is contained in:
Andre Natal 2016-10-29 15:21:01 -07:00 committed by Andre
parent bc8b046605
commit 32a436309e
2 changed files with 371 additions and 9 deletions

360
util/importers/swb.py Normal file
View File

@ -0,0 +1,360 @@
import fnmatch
import numpy as np
import os
import random
import subprocess
import wave
import tensorflow as tf
import unicodedata
import codecs
from glob import glob
from itertools import cycle
from math import ceil
from Queue import PriorityQueue
from Queue import Queue
from threading import Thread
from util.audio import audiofile_to_input_vector
from util.gpu import get_available_gpus
from util.text import text_to_char_array, ctc_label_dense_to_sparse, validate_label
class DataSets(object):
    """Container holding the train, dev and test splits side by side."""

    def __init__(self, train, dev, test):
        self._train = train
        self._dev = dev
        self._test = test

    def start_queue_threads(self, session):
        """Start the queue threads of each split, in dev, test, train order."""
        for split in (self._dev, self._test, self._train):
            split.start_queue_threads(session)

    @property
    def train(self):
        """The training split passed to the constructor."""
        return self._train

    @property
    def dev(self):
        """The dev (validation) split passed to the constructor."""
        return self._dev

    @property
    def test(self):
        """The test split passed to the constructor."""
        return self._test
class DataSet(object):
    """Streams (audio features, transcript) examples into a TensorFlow queue.

    Examples are read from pairs of ``<name>.txt`` / ``<name>.wav`` files and
    pushed into a ``tf.PaddingFIFOQueue`` by background threads; batches are
    pulled out again with ``next_batch``.
    """

    def __init__(self, txt_files, thread_count, batch_size, numcep, numcontext, dataset_dir):
        # NOTE: dataset_dir is accepted for interface compatibility but is
        # not stored or used by this class.
        feature_width = numcep + (2 * numcep * numcontext)
        self._numcep = numcep

        # Placeholders fed by the producer threads, one example at a time.
        self._x = tf.placeholder(tf.float32, [None, feature_width])
        self._x_length = tf.placeholder(tf.int32, [])
        self._y = tf.placeholder(tf.int32, [None, ])
        self._y_length = tf.placeholder(tf.int32, [])

        # Room for two batches per device.
        queue_capacity = 2 * self._get_device_count() * batch_size
        self._example_queue = tf.PaddingFIFOQueue(
            shapes=[[None, feature_width], [], [None, ], []],
            dtypes=[tf.float32, tf.int32, tf.int32, tf.int32],
            capacity=queue_capacity)
        self._enqueue_op = self._example_queue.enqueue(
            [self._x, self._x_length, self._y, self._y_length])

        self._txt_files = txt_files
        self._batch_size = batch_size
        self._numcontext = numcontext
        self._thread_count = thread_count
        self._files_circular_list = self._create_files_circular_list()

    def _get_device_count(self):
        """Return the number of available GPUs, but never less than one."""
        gpu_count = len(get_available_gpus())
        return max(gpu_count, 1)

    def start_queue_threads(self, session):
        """Launch ``thread_count`` daemon threads that keep the queue filled."""
        workers = []
        for _ in range(self._thread_count):
            workers.append(Thread(target=self._populate_batch_queue, args=(session,)))
        for worker in workers:
            worker.daemon = True
            worker.start()

    def _create_files_circular_list(self):
        """Return an endless iterator over (txt, wav) pairs, smallest wav first."""
        sized_pairs = []
        for txt_file in self._txt_files:
            wav_file = os.path.splitext(txt_file)[0] + ".wav"
            sized_pairs.append((os.path.getsize(wav_file), (txt_file, wav_file)))
        # Ascending by wav size; ties fall back to comparing the path pair.
        sized_pairs.sort()
        return cycle([pair for _, pair in sized_pairs])

    def _populate_batch_queue(self, session):
        """Thread body: endlessly load examples and enqueue them via ``session``."""
        for txt_file, wav_file in self._files_circular_list:
            features = audiofile_to_input_vector(wav_file, self._numcep, self._numcontext)
            with codecs.open(txt_file, encoding="utf-8") as open_txt_file:
                raw_text = open_txt_file.read()
            # Decompose accented characters, then drop anything non-ASCII.
            ascii_text = unicodedata.normalize("NFKD", raw_text).encode("ascii", "ignore")
            labels = text_to_char_array(ascii_text)
            feed = {
                self._x: features,
                self._x_length: len(features),
                self._y: labels,
                self._y_length: len(labels),
            }
            session.run(self._enqueue_op, feed_dict=feed)

    def next_batch(self):
        """Return the dequeue tensors (source, source lengths, target, target lengths)."""
        features, feature_lengths, labels, label_lengths = \
            self._example_queue.dequeue_many(self._batch_size)
        return features, feature_lengths, labels, label_lengths

    @property
    def total_batches(self):
        """Number of batches needed to cover every txt file once (rounded up)."""
        # Note: If len(_txt_files) % _batch_size != 0, this re-uses initial _txt_files
        example_count = float(len(self._txt_files))
        return int(ceil(example_count / float(self._batch_size)))
def read_data_sets(data_dir, batch_size, numcep, numcontext, thread_count=8, limit_dev=0, limit_test=0, limit_train=0):
    """Prepare the Switchboard corpus under ``data_dir`` and return its DataSets.

    The corpus is expected in ``<data_dir>/LDC97S62``. Each preparation step
    (sph->wav conversion, per-utterance wav splitting, transcript splitting,
    train/dev/test partitioning) is delegated to a ``_maybe_*`` helper.

    A ``limit_*`` value greater than zero caps the number of examples of the
    corresponding split.
    """
    data_dir = os.path.join(data_dir, "LDC97S62")
    transcriptions = "swb_ms98_transcriptions"
    discs = ["swb1_d1", "swb1_d2", "swb1_d3", "swb1_d4"]

    # Conditionally convert swb sph data to wav
    for disc in discs:
        _maybe_convert_wav(data_dir, disc, disc + "-wav")

    # Conditionally split wav data
    for disc in discs:
        _maybe_split_wav(data_dir, transcriptions, disc + "-wav", disc + "-split-wav")

    _maybe_split_transcriptions(data_dir, transcriptions)

    _maybe_split_sets(data_dir, [disc + "-split-wav" for disc in discs], "final_sets")

    def load_split(split_dir, limit):
        return _read_data_set(data_dir, split_dir, thread_count, batch_size, numcep,
                              numcontext, limit)

    # Build the splits in dev, test, train order.
    dev = load_split("final_sets/dev", limit_dev)
    test = load_split("final_sets/test", limit_test)
    train = load_split("final_sets/train", limit_train)

    return DataSets(train, dev, test)
def _maybe_convert_wav(data_dir, original_data, converted_data):
source_dir = os.path.join(data_dir, original_data)
target_dir = os.path.join(data_dir, converted_data)
# Conditionally convert sph files to wav files
if os.path.exists(target_dir):
print("skipping maybe_convert_wav")
return
# Create target_dir
os.makedirs(target_dir)
# Loop over sph files in source_dir and convert each to 16-bit PCM wav
for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, "*.sph"):
for channel in ['1', '2']:
sph_file = os.path.join(root, filename)
wav_filename = os.path.splitext(os.path.basename(sph_file))[0] + "-" + channel + ".wav"
wav_file = os.path.join(target_dir, wav_filename)
print("converting {} to {}".format(sph_file, wav_file))
subprocess.check_call(["sph2pipe", "-c", channel, "-p", "-f", "rif", sph_file, wav_file])
def _parse_transcriptions(trans_file):
    """Parse a transcription file into a list of segment dictionaries.

    Each usable line has four space-separated parts: an utterance id, a
    start time, a stop time and the transcript (which may itself contain
    spaces). Lines starting with ``#``, empty lines, lines with fewer than
    four parts and lines whose transcript is rejected by ``validate_label``
    are skipped.

    Returns:
        A list of dicts with keys ``start_time`` and ``stop_time`` (floats),
        ``speaker`` (the character at index 6 of the line) and
        ``transcript`` (stripped and lower-cased).

    Raises:
        ValueError: if a start or stop time is not a valid float.
    """
    segments = []
    with open(trans_file, "r") as fin:
        for line in fin:
            if line.startswith("#") or len(line) <= 1:
                continue

            # Split off at most three leading fields so the transcript keeps
            # its internal spaces.
            fields = line.split(" ", 3)
            if len(fields) < 4:
                # Not enough fields to hold id, times and transcript.
                continue

            transcript = fields[3].strip()
            if validate_label(transcript) is None:
                continue

            segments.append({
                "start_time": float(fields[1]),
                "stop_time": float(fields[2]),
                "speaker": line[6],
                "transcript": transcript.lower(),
            })
    return segments
def _maybe_split_wav(data_dir, trans_data, original_data, converted_data):
trans_dir = os.path.join(data_dir, trans_data)
source_dir = os.path.join(data_dir, original_data)
target_dir = os.path.join(data_dir, converted_data)
if os.path.exists(target_dir):
print("skipping maybe_split_wav")
return
os.makedirs(target_dir)
# Loop over transcription files and split corresponding wav
for root, dirnames, filenames in os.walk(trans_dir):
for filename in fnmatch.filter(filenames, "*.text"):
if "trans" not in filename:
continue
trans_file = os.path.join(root, filename)
segments = _parse_transcriptions(trans_file)
# Open wav corresponding to transcription file
channel = ("2","1")[(os.path.splitext(os.path.basename(trans_file))[0])[6] == 'A']
wav_filename = "sw0" + (os.path.splitext(os.path.basename(trans_file))[0])[2:6] + "-" + channel + ".wav"
wav_file = os.path.join(source_dir, wav_filename)
print("splitting {} according to {}".format(wav_file, trans_file))
if not os.path.exists(wav_file):
print("skipping. does not exist:" + wav_file)
continue
origAudio = wave.open(wav_file, "r")
# Loop over segments and split wav_file for each segment
for segment in segments:
# Create wav segment filename
start_time = segment["start_time"]
stop_time = segment["stop_time"]
new_wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(
start_time) + "-" + str(stop_time) + ".wav"
new_wav_file = os.path.join(target_dir, new_wav_filename)
# If the wav segment filename does not exist create it
if not os.path.exists(new_wav_file):
_split_wav(origAudio, start_time, stop_time, new_wav_file)
# Close origAudio
origAudio.close()
# Remove wav_file
# os.remove(wav_file)
def _split_wav(origAudio, start_time, stop_time, new_wav_file):
frameRate = origAudio.getframerate()
origAudio.setpos(int(start_time * frameRate))
chunkData = origAudio.readframes(int((stop_time - start_time) * frameRate))
chunkAudio = wave.open(new_wav_file, "w")
chunkAudio.setnchannels(origAudio.getnchannels())
chunkAudio.setsampwidth(origAudio.getsampwidth())
chunkAudio.setframerate(frameRate)
chunkAudio.writeframes(chunkData)
chunkAudio.close()
def _maybe_split_transcriptions(data_dir, original_data):
source_dir = os.path.join(data_dir, original_data)
wav_dirs = ["swb1_d1-split-wav", "swb1_d2-split-wav", "swb1_d3-split-wav", "swb1_d4-split-wav"]
if os.path.exists(os.path.join(source_dir, "split_transcriptions_done")):
print("skipping maybe_split_transcriptions")
return
# Loop over transcription files and split them into individual files for
# each utterance
for root, dirnames, filenames in os.walk(source_dir):
for filename in fnmatch.filter(filenames, "*.text"):
if "trans" not in filename:
continue
trans_file = os.path.join(root, filename)
segments = _parse_transcriptions(trans_file)
# Loop over segments and split wav_file for each segment
for segment in segments:
start_time = segment["start_time"]
stop_time = segment["stop_time"]
txt_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(start_time) + "-" + str(
stop_time) + ".txt"
wav_filename = os.path.splitext(os.path.basename(trans_file))[0] + "-" + str(start_time) + "-" + str(
stop_time) + ".wav"
transcript = validate_label(segment["transcript"])
for wav_dir in wav_dirs:
if os.path.exists(os.path.join(data_dir, wav_dir, wav_filename)):
# If the transcript is valid and the txt segment filename does
# not exist create it
txt_file = os.path.join(data_dir, wav_dir, txt_filename)
if transcript != None and not os.path.exists(txt_file):
with open(txt_file, "w") as fout:
fout.write(transcript)
break
with open(os.path.join(source_dir, "split_transcriptions_done"), "w") as fout:
fout.write(
"This file signals to the importer than the transcription of this source dir has already been completed.")
def _maybe_split_sets(data_dir, original_data, converted_data):
target_dir = os.path.join(data_dir, converted_data)
if os.path.exists(target_dir):
return;
os.makedirs(target_dir)
filelist = []
for dir in original_data:
source_dir = os.path.join(data_dir, dir)
filelist.extend(glob(os.path.join(source_dir, "*.txt")))
# We initially split the entire set into 80% train and 20% test, then
# split the train set into 80% train and 20% validation.
train_beg = 0
train_end = int(0.8 * len(filelist))
dev_beg = int(0.8 * train_end)
dev_end = train_end
train_end = dev_beg
test_beg = dev_end
test_end = len(filelist)
_maybe_split_dataset(filelist[train_beg:train_end], os.path.join(target_dir, "train"))
_maybe_split_dataset(filelist[dev_beg:dev_end], os.path.join(target_dir, "dev"))
_maybe_split_dataset(filelist[test_beg:test_end], os.path.join(target_dir, "test"))
def _maybe_split_dataset(filelist, target_dir):
if not os.path.exists(target_dir):
os.makedirs(target_dir)
for txt_file in filelist:
new_txt_file = os.path.join(target_dir, os.path.basename(txt_file))
os.rename(txt_file, new_txt_file)
wav_file = os.path.splitext(txt_file)[0] + ".wav"
new_wav_file = os.path.join(target_dir, os.path.basename(wav_file))
os.rename(wav_file, new_wav_file)
def _read_data_set(work_dir, data_set, thread_count, batch_size, numcep, numcontext, limit=0):
    """Build a ``DataSet`` from the txt files found in ``<work_dir>/<data_set>``.

    When ``limit`` is greater than zero only the first ``limit`` files (in
    glob order) are used.
    """
    dataset_dir = os.path.join(work_dir, data_set)

    transcripts = glob(os.path.join(work_dir, data_set, "*.txt"))
    if limit > 0:
        transcripts = transcripts[:limit]

    return DataSet(transcripts, thread_count, batch_size, numcep, numcontext, dataset_dir)

View File

@ -173,19 +173,21 @@ def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
def validate_label(label):
    """Return a cleaned-up transcript, or ``None`` when it cannot be used.

    A label is rejected (``None``) when it contains any of the characters
    ``( < [ ] & * {`` or any digit. Otherwise the characters ``- _ . , ?``
    are removed and surrounding whitespace is stripped.
    """
    # For now we can only handle [a-z ']
    if any(marker in label for marker in ("(", "<", "[", "]", "&", "*", "{")):
        return None
    if re.search(r"[0-9]", label) is not None:
        return None

    for punctuation in ("-", "_", ".", ",", "?"):
        label = label.replace(punctuation, "")

    return label.strip()