diff --git a/bin/import_cv.py b/bin/import_cv.py index 7dd04d84..a9b9447e 100755 --- a/bin/import_cv.py +++ b/bin/import_cv.py @@ -18,7 +18,7 @@ from os import path from threading import RLock from multiprocessing.dummy import Pool from multiprocessing import cpu_count -from util.text import validate_label +from util.importers import validate_label_eng as validate_label from util.downloader import maybe_download, SIMPLE_BAR FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] diff --git a/bin/import_cv2.py b/bin/import_cv2.py index a4aba6bc..7f8222d7 100755 --- a/bin/import_cv2.py +++ b/bin/import_cv2.py @@ -26,7 +26,7 @@ from multiprocessing.dummy import Pool from multiprocessing import cpu_count from util.downloader import SIMPLE_BAR from util.text import Alphabet -from util.importers import get_importers_parser, validate_label_eng as validate_label +from util.importers import get_importers_parser, get_validate_label from util.helpers import secs_to_hours @@ -144,6 +144,7 @@ if __name__ == "__main__": PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space') PARAMS = PARSER.parse_args() + validate_label = get_validate_label(PARAMS) AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips') ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None diff --git a/bin/import_fisher.py b/bin/import_fisher.py index e3340244..dd054765 100755 --- a/bin/import_fisher.py +++ b/bin/import_fisher.py @@ -19,7 +19,7 @@ import unicodedata import librosa import soundfile # <= Has an external dependency on libsndfile -from util.text import validate_label +from util.importers import validate_label_eng as validate_label def _download_and_preprocess_data(data_dir): # Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19 diff --git a/bin/import_gram_vaani.py b/bin/import_gram_vaani.py index 500ed5de..141478b8 100755 --- a/bin/import_gram_vaani.py +++ b/bin/import_gram_vaani.py @@ -10,7 +10,7 @@ import csv import math import urllib import logging -from util.importers import get_importers_parser +from util.importers import get_importers_parser, get_validate_label import subprocess from os import path from pathlib import Path @@ -19,8 +19,6 @@ import swifter import pandas as pd from sox import Transformer -from util.text import validate_label - __version__ = "0.1.0" _logger = logging.getLogger(__name__) @@ -290,6 +288,7 @@ def main(args): args ([str]): command line parameter list """ args = parse_args(args) + validate_label = get_validate_label(args) setup_logging(args.loglevel) _logger.info("Starting GramVaani importer...") _logger.info("Starting loading GramVaani csv...") diff --git a/bin/import_lingua_libre.py b/bin/import_lingua_libre.py index b9d6106a..bc11203d 100755 --- a/bin/import_lingua_libre.py +++ b/bin/import_lingua_libre.py @@ -7,8 +7,9 @@ import os import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) -from util.importers import get_importers_parser +from util.importers import get_importers_parser, get_validate_label +import argparse import csv import re import sox @@ -26,7 +27,7 @@ from os import path from glob import glob from util.downloader import maybe_download -from util.text import Alphabet, validate_label +from util.text import Alphabet from util.helpers import secs_to_hours FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] @@ -185,6 +186,7 @@ def handle_args(): if __name__ == "__main__": CLI_ARGS = handle_args() ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None + validate_label = get_validate_label(CLI_ARGS) bogus_regexes = [] if CLI_ARGS.bogus_records: diff --git a/bin/import_m-ailabs.py b/bin/import_m-ailabs.py index 16d1bf54..540c8139 100755 --- a/bin/import_m-ailabs.py +++ b/bin/import_m-ailabs.py @@ -9,7 +9,7 @@ import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) -from util.importers import get_importers_parser +from util.importers import get_importers_parser, get_validate_label import csv import subprocess @@ -26,7 +26,7 @@ from os import path from glob import glob from util.downloader import maybe_download -from util.text import Alphabet, validate_label +from util.text import Alphabet from util.helpers import secs_to_hours FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] @@ -182,6 +182,7 @@ if __name__ == "__main__": CLI_ARGS = handle_args() ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(',')) + validate_label = get_validate_label(CLI_ARGS) def label_filter(label): if CLI_ARGS.normalize: diff --git a/bin/import_slr57.py b/bin/import_slr57.py index 16bac05b..b5bbef9c 100755 --- a/bin/import_slr57.py +++ b/bin/import_slr57.py @@ -7,7 +7,7 @@ import os import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) -from util.importers import get_importers_parser +from util.importers import get_importers_parser, get_validate_label import csv import re @@ -27,7 +27,7 @@ from os import path from glob import glob from util.downloader import maybe_download -from util.text import Alphabet, validate_label +from util.text import Alphabet from util.helpers import secs_to_hours FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] @@ -203,6 +203,7 @@ def handle_args(): if __name__ == "__main__": CLI_ARGS = handle_args() ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None + validate_label = get_validate_label(CLI_ARGS) def label_filter(label): if CLI_ARGS.normalize: diff --git a/bin/import_swb.py b/bin/import_swb.py index e4261aa2..b682ae30 100755 --- a/bin/import_swb.py +++ b/bin/import_swb.py @@ -20,7 +20,7 @@ import wave import codecs import tarfile import requests -from util.text import validate_label +from util.importers import validate_label_eng as validate_label import librosa import soundfile # <= Has an external dependency on libsndfile diff --git a/bin/import_swc.py b/bin/import_swc.py index 93410805..e5114156 100755 --- a/bin/import_swc.py +++ b/bin/import_swc.py @@ -27,7 +27,8 @@ from os import path from glob import glob from collections import Counter from multiprocessing.pool import ThreadPool -from util.text import Alphabet, validate_label +from util.text import Alphabet +from util.importers import validate_label_eng as validate_label from util.downloader import maybe_download, SIMPLE_BAR SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar" diff --git a/bin/import_ts.py b/bin/import_ts.py index 363e639e..d899f1a3 100755 --- a/bin/import_ts.py +++ b/bin/import_ts.py @@ -8,7 +8,7 @@ import re import sys sys.path.insert(1, os.path.join(sys.path[0], '..')) -from util.importers import get_importers_parser +from util.importers import get_importers_parser, get_validate_label import csv import unidecode @@ -25,7 +25,6 @@ from util.downloader import SIMPLE_BAR from os import path from util.downloader import maybe_download -from util.text import validate_label from util.helpers import secs_to_hours FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript'] @@ -193,4 +192,5 @@ def handle_args(): if __name__ == "__main__": cli_args = handle_args() + validate_label = get_validate_label(cli_args) _download_and_preprocess_data(cli_args.target_dir, cli_args.english_compatible) diff --git a/bin/import_tuda.py b/bin/import_tuda.py index 89590144..857be405 100755 --- a/bin/import_tuda.py +++ b/bin/import_tuda.py @@ -21,7 +21,8 @@ import xml.etree.cElementTree as ET from os import path from collections import Counter -from util.text import Alphabet, validate_label +from util.text import Alphabet +from util.importers import validate_label_eng as validate_label from util.downloader import maybe_download, SIMPLE_BAR TUDA_VERSION = 'v2' diff --git a/requirements_tests.txt b/requirements_tests.txt index b998a06a..1e472e22 100644 --- a/requirements_tests.txt +++ b/requirements_tests.txt @@ -1 +1,2 @@ absl-py +argparse diff --git a/util/importers.py b/util/importers.py index 9f7ba8df..3efec973 100644 --- a/util/importers.py +++ b/util/importers.py @@ -1,10 +1,38 @@ import argparse +import importlib +import os import re +import sys def get_importers_parser(description): parser = argparse.ArgumentParser(description=description) + parser.add_argument('--validate_label_locale', help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.') return parser +def get_validate_label(args): + """ + Expects an argparse.Namespace argument to search for validate_label_locale parameter. + If found, this will modify Python's library search path and add the directory of the + file pointed by the validate_label_locale argument. + + :param args: The importer's CLI argument object + :type args: argparse.Namespace + + :return: The user-supplied validate_label function + :type: function + """ + if 'validate_label_locale' not in args or (args.validate_label_locale is None): + print('WARNING: No --validate_label_locale specified, your might end with inconsistent dataset.') + return validate_label_eng + if not os.path.exists(os.path.abspath(args.validate_label_locale)): + print('ERROR: Inexistent --validate_label_locale specified. Please check.') + return None + module_dir = os.path.abspath(os.path.dirname(args.validate_label_locale)) + sys.path.insert(1, module_dir) + fname = os.path.basename(args.validate_label_locale).replace('.py', '') + locale_module = importlib.import_module(fname, package=None) + return locale_module.validate_label + # Validate and normalize transcriptions. Returns a cleaned version of the label # or None if it's invalid. def validate_label_eng(label): diff --git a/util/test_data/validate_locale_fra.py b/util/test_data/validate_locale_fra.py new file mode 100644 index 00000000..4265fcde --- /dev/null +++ b/util/test_data/validate_locale_fra.py @@ -0,0 +1,2 @@ +def validate_label(label): + return label diff --git a/util/test_importers.py b/util/test_importers.py index 884c2193..281e4ee1 100644 --- a/util/test_importers.py +++ b/util/test_importers.py @@ -1,6 +1,7 @@ import unittest -from .importers import validate_label_eng +from argparse import Namespace +from .importers import validate_label_eng, get_validate_label class TestValidateLabelEng(unittest.TestCase): @@ -8,5 +9,30 @@ class TestValidateLabelEng(unittest.TestCase): label = validate_label_eng("this is a 1 2 3 test") self.assertEqual(label, None) +class TestGetValidateLabel(unittest.TestCase): + + def test_no_validate_label_locale(self): + f = get_validate_label(Namespace()) + self.assertEqual(f('toto'), 'toto') + self.assertEqual(f('toto1234'), None) + self.assertEqual(f('toto1234[{[{[]'), None) + + def test_validate_label_locale_default(self): + f = get_validate_label(Namespace(validate_label_locale=None)) + self.assertEqual(f('toto'), 'toto') + self.assertEqual(f('toto1234'), None) + self.assertEqual(f('toto1234[{[{[]'), None) + + def test_get_validate_label_missing(self): + args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py') + f = get_validate_label(args) + self.assertEqual(f, None) + + def test_get_validate_label(self): + args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py') + f = get_validate_label(args) + l = f('toto') + self.assertEqual(l, 'toto') + if __name__ == '__main__': unittest.main()