mirror of
https://github.com/mozilla/DeepSpeech.git
synced 2025-10-26 11:19:39 +00:00
parent
f9e05fe0c3
commit
ce59228824
@ -18,7 +18,7 @@ from os import path
|
||||
from threading import RLock
|
||||
from multiprocessing.dummy import Pool
|
||||
from multiprocessing import cpu_count
|
||||
from util.text import validate_label
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
|
||||
@ -26,7 +26,7 @@ from multiprocessing.dummy import Pool
|
||||
from multiprocessing import cpu_count
|
||||
from util.downloader import SIMPLE_BAR
|
||||
from util.text import Alphabet
|
||||
from util.importers import get_importers_parser, validate_label_eng as validate_label
|
||||
from util.importers import get_importers_parser, get_validate_label
|
||||
from util.helpers import secs_to_hours
|
||||
|
||||
|
||||
@ -144,6 +144,7 @@ if __name__ == "__main__":
|
||||
PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space')
|
||||
|
||||
PARAMS = PARSER.parse_args()
|
||||
validate_label = get_validate_label(PARAMS)
|
||||
|
||||
AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips')
|
||||
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None
|
||||
|
||||
@ -19,7 +19,7 @@ import unicodedata
|
||||
import librosa
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
|
||||
from util.text import validate_label
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
|
||||
def _download_and_preprocess_data(data_dir):
|
||||
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19
|
||||
|
||||
@ -10,7 +10,7 @@ import csv
|
||||
import math
|
||||
import urllib
|
||||
import logging
|
||||
from util.importers import get_importers_parser
|
||||
from util.importers import get_importers_parser, get_validate_label
|
||||
import subprocess
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
@ -19,8 +19,6 @@ import swifter
|
||||
import pandas as pd
|
||||
from sox import Transformer
|
||||
|
||||
from util.text import validate_label
|
||||
|
||||
|
||||
__version__ = "0.1.0"
|
||||
_logger = logging.getLogger(__name__)
|
||||
@ -290,6 +288,7 @@ def main(args):
|
||||
args ([str]): command line parameter list
|
||||
"""
|
||||
args = parse_args(args)
|
||||
validate_label = get_validate_label(args)
|
||||
setup_logging(args.loglevel)
|
||||
_logger.info("Starting GramVaani importer...")
|
||||
_logger.info("Starting loading GramVaani csv...")
|
||||
|
||||
@ -7,8 +7,9 @@ import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
from util.importers import get_importers_parser, get_validate_label
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import re
|
||||
import sox
|
||||
@ -26,7 +27,7 @@ from os import path
|
||||
from glob import glob
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import Alphabet, validate_label
|
||||
from util.text import Alphabet
|
||||
from util.helpers import secs_to_hours
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
@ -185,6 +186,7 @@ def handle_args():
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
|
||||
validate_label = get_validate_label(CLI_ARGS)
|
||||
|
||||
bogus_regexes = []
|
||||
if CLI_ARGS.bogus_records:
|
||||
|
||||
@ -9,7 +9,7 @@ import sys
|
||||
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
from util.importers import get_importers_parser, get_validate_label
|
||||
|
||||
import csv
|
||||
import subprocess
|
||||
@ -26,7 +26,7 @@ from os import path
|
||||
from glob import glob
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import Alphabet, validate_label
|
||||
from util.text import Alphabet
|
||||
from util.helpers import secs_to_hours
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
@ -182,6 +182,7 @@ if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
|
||||
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(','))
|
||||
validate_label = get_validate_label(CLI_ARGS)
|
||||
|
||||
def label_filter(label):
|
||||
if CLI_ARGS.normalize:
|
||||
|
||||
@ -7,7 +7,7 @@ import os
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
from util.importers import get_importers_parser, get_validate_label
|
||||
|
||||
import csv
|
||||
import re
|
||||
@ -27,7 +27,7 @@ from os import path
|
||||
from glob import glob
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import Alphabet, validate_label
|
||||
from util.text import Alphabet
|
||||
from util.helpers import secs_to_hours
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
@ -203,6 +203,7 @@ def handle_args():
|
||||
if __name__ == "__main__":
|
||||
CLI_ARGS = handle_args()
|
||||
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
|
||||
validate_label = get_validate_label(CLI_ARGS)
|
||||
|
||||
def label_filter(label):
|
||||
if CLI_ARGS.normalize:
|
||||
|
||||
@ -20,7 +20,7 @@ import wave
|
||||
import codecs
|
||||
import tarfile
|
||||
import requests
|
||||
from util.text import validate_label
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
import librosa
|
||||
import soundfile # <= Has an external dependency on libsndfile
|
||||
|
||||
|
||||
@ -27,7 +27,8 @@ from os import path
|
||||
from glob import glob
|
||||
from collections import Counter
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from util.text import Alphabet, validate_label
|
||||
from util.text import Alphabet
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
|
||||
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"
|
||||
|
||||
@ -8,7 +8,7 @@ import re
|
||||
import sys
|
||||
sys.path.insert(1, os.path.join(sys.path[0], '..'))
|
||||
|
||||
from util.importers import get_importers_parser
|
||||
from util.importers import get_importers_parser, get_validate_label
|
||||
|
||||
import csv
|
||||
import unidecode
|
||||
@ -25,7 +25,6 @@ from util.downloader import SIMPLE_BAR
|
||||
from os import path
|
||||
|
||||
from util.downloader import maybe_download
|
||||
from util.text import validate_label
|
||||
from util.helpers import secs_to_hours
|
||||
|
||||
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
|
||||
@ -193,4 +192,5 @@ def handle_args():
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_args = handle_args()
|
||||
validate_label = get_validate_label(cli_args)
|
||||
_download_and_preprocess_data(cli_args.target_dir, cli_args.english_compatible)
|
||||
|
||||
@ -21,7 +21,8 @@ import xml.etree.cElementTree as ET
|
||||
|
||||
from os import path
|
||||
from collections import Counter
|
||||
from util.text import Alphabet, validate_label
|
||||
from util.text import Alphabet
|
||||
from util.importers import validate_label_eng as validate_label
|
||||
from util.downloader import maybe_download, SIMPLE_BAR
|
||||
|
||||
TUDA_VERSION = 'v2'
|
||||
|
||||
@ -1 +1,2 @@
|
||||
absl-py
|
||||
argparse
|
||||
|
||||
@ -1,10 +1,38 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
def get_importers_parser(description):
    """
    Build the argument parser shared by the importer scripts.

    The returned parser already defines ``--validate_label_locale``; callers
    add their own importer-specific arguments on top of it.

    :param description: Text passed to argparse.ArgumentParser as the description
    :type description: str

    :return: A parser with the ``--validate_label_locale`` option registered
    :type: argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--validate_label_locale', help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.')
    return parser
|
||||
|
||||
def get_validate_label(args):
    """
    Resolve the validate_label function an importer should use.

    Expects an argparse.Namespace argument to search for validate_label_locale parameter.
    If found, this will modify Python's library search path and add the directory of the
    file pointed by the validate_label_locale argument.

    :param args: The importer's CLI argument object
    :type args: argparse.Namespace

    :return: The user-supplied validate_label function. Falls back to
             validate_label_eng (after printing a warning) when no
             validate_label_locale was given, and returns None (after printing
             an error) when the given file does not exist.
    :type: function
    """
    # getattr covers both "attribute absent" and "attribute explicitly None".
    locale_path = getattr(args, 'validate_label_locale', None)
    if locale_path is None:
        print('WARNING: No --validate_label_locale specified, your might end with inconsistent dataset.')
        return validate_label_eng

    if not os.path.exists(os.path.abspath(locale_path)):
        print('ERROR: Inexistent --validate_label_locale specified. Please check.')
        return None

    # Make the locale file importable by its bare module name.
    module_dir = os.path.abspath(os.path.dirname(locale_path))
    sys.path.insert(1, module_dir)

    # Strip only the file extension; str.replace('.py', '') would also remove
    # any other ".py" occurrence in the file name (e.g. "foo.pyc" -> "fooc").
    module_name, _ = os.path.splitext(os.path.basename(locale_path))
    locale_module = importlib.import_module(module_name, package=None)
    return locale_module.validate_label
|
||||
|
||||
# Validate and normalize transcriptions. Returns a cleaned version of the label
|
||||
# or None if it's invalid.
|
||||
def validate_label_eng(label):
|
||||
|
||||
2
util/test_data/validate_locale_fra.py
Normal file
2
util/test_data/validate_locale_fra.py
Normal file
@ -0,0 +1,2 @@
|
||||
def validate_label(label):
    """Test fixture: accept every label and return it unchanged."""
    return label
|
||||
@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
|
||||
from .importers import validate_label_eng
|
||||
from argparse import Namespace
|
||||
from .importers import validate_label_eng, get_validate_label
|
||||
|
||||
class TestValidateLabelEng(unittest.TestCase):
|
||||
|
||||
@ -8,5 +9,30 @@ class TestValidateLabelEng(unittest.TestCase):
|
||||
label = validate_label_eng("this is a 1 2 3 test")
|
||||
self.assertEqual(label, None)
|
||||
|
||||
class TestGetValidateLabel(unittest.TestCase):
    """Checks get_validate_label for absent, unset, missing and valid locale files."""

    def _assert_english_defaults(self, validate):
        # Shared expectations for the validator returned when no locale file
        # is requested: plain text passes, digits and brackets are rejected.
        self.assertEqual(validate('toto'), 'toto')
        self.assertIsNone(validate('toto1234'))
        self.assertIsNone(validate('toto1234[{[{[]'))

    def test_no_validate_label_locale(self):
        # Namespace without the validate_label_locale attribute at all.
        self._assert_english_defaults(get_validate_label(Namespace()))

    def test_validate_label_locale_default(self):
        # Attribute present but left at None (option not passed on the CLI).
        self._assert_english_defaults(
            get_validate_label(Namespace(validate_label_locale=None)))

    def test_get_validate_label_missing(self):
        # A path that does not exist yields None instead of a function.
        args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py')
        self.assertIsNone(get_validate_label(args))

    def test_get_validate_label(self):
        # The fixture module's validate_label returns its input unchanged.
        args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py')
        validate = get_validate_label(args)
        self.assertEqual(validate('toto'), 'toto')
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user