DeepSpeech/util/text.py
2019-04-11 07:02:21 -03:00

133 lines
4.4 KiB
Python

from __future__ import absolute_import, division, print_function
import codecs
import re
import numpy as np
from six.moves import range
class Alphabet(object):
def __init__(self, config_file):
self._config_file = config_file
self._label_to_str = []
self._str_to_label = {}
self._size = 0
with codecs.open(config_file, 'r', 'utf-8') as fin:
for line in fin:
if line[0:2] == '\\#':
line = '#\n'
elif line[0] == '#':
continue
self._label_to_str += line[:-1] # remove the line ending
self._str_to_label[line[:-1]] = self._size
self._size += 1
def string_from_label(self, label):
return self._label_to_str[label]
def label_from_string(self, string):
try:
return self._str_to_label[string]
except KeyError as e:
raise KeyError(
'''ERROR: Your transcripts contain characters which do not occur in data/alphabet.txt! Use util/check_characters.py to see what characters are in your {train,dev,test}.csv transcripts, and then add all these to data/alphabet.txt.'''
).with_traceback(e.__traceback__)
def decode(self, labels):
res = ''
for label in labels:
res += self.string_from_label(label)
return res
def size(self):
return self._size
def config_file(self):
return self._config_file
def text_to_char_array(original, alphabet):
r"""
Given a Python string ``original``, remove unsupported characters, map characters
to integers and return a numpy array representing the processed string.
"""
return np.asarray([alphabet.label_from_string(c) for c in original])
def wer_cer_batch(originals, results):
r"""
The WER is defined as the editing/Levenshtein distance on word level
divided by the amount of words in the original text.
In case of the original having more words (N) than the result and both
being totally different (all N words resulting in 1 edit operation each),
the WER will always be 1 (N / N = 1).
"""
# The WER is calculated on word (and NOT on character) level.
# Therefore we split the strings into words first
assert len(originals) == len(results)
total_cer = 0.0
total_char_length = 0.0
total_wer = 0.0
total_word_length = 0.0
for original, result in zip(originals, results):
total_cer += levenshtein(original, result)
total_char_length += len(original)
total_wer += levenshtein(original.split(), result.split())
total_word_length += len(original.split())
return total_wer / total_word_length, total_cer / total_char_length
# The following code is from: http://hetland.org/coding/python/levenshtein.py
# This is a straightforward implementation of a well-known algorithm, and thus
# probably shouldn't be covered by copyright to begin with. But in case it is,
# the author (Magnus Lie Hetland) has, to the extent possible under law,
# dedicated all copyright and related and neighboring rights to this software
# to the public domain worldwide, by distributing it under the CC0 license,
# version 1.0. This software is distributed without any warranty. For more
# information, see <http://creativecommons.org/publicdomain/zero/1.0>
def levenshtein(a, b):
"Calculates the Levenshtein distance between a and b."
n, m = len(a), len(b)
if n > m:
# Make sure n <= m, to use O(min(n,m)) space
a, b = b, a
n, m = m, n
current = list(range(n+1))
for i in range(1, m+1):
previous, current = current, [i]+[0]*n
for j in range(1, n+1):
add, delete = previous[j]+1, current[j-1]+1
change = previous[j-1]
if a[j-1] != b[i-1]:
change = change + 1
current[j] = min(add, delete, change)
return current[n]
# Validate and normalize transcriptions. Returns a cleaned version of the label
# or None if it's invalid.
def validate_label(label):
# For now we can only handle [a-z ']
if re.search(r"[0-9]|[(<\[\]&*{]", label) is not None:
return None
label = label.replace("-", "")
label = label.replace("_", "")
label = label.replace(".", "")
label = label.replace(",", "")
label = label.replace("?", "")
label = label.replace("\"", "")
label = label.strip()
label = label.lower()
return label if label else None