Localizeable validate_label

Fixes #2804
This commit is contained in:
Alexandre Lissy 2020-03-10 16:22:08 +01:00 committed by Alexandre Lissy
parent f9e05fe0c3
commit ce59228824
15 changed files with 81 additions and 18 deletions

View File

@ -18,7 +18,7 @@ from os import path
from threading import RLock
from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.text import validate_label
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']

View File

@ -26,7 +26,7 @@ from multiprocessing.dummy import Pool
from multiprocessing import cpu_count
from util.downloader import SIMPLE_BAR
from util.text import Alphabet
from util.importers import get_importers_parser, validate_label_eng as validate_label
from util.importers import get_importers_parser, get_validate_label
from util.helpers import secs_to_hours
@ -144,6 +144,7 @@ if __name__ == "__main__":
PARSER.add_argument('--space_after_every_character', action='store_true', help='To help transcript join by white space')
PARAMS = PARSER.parse_args()
validate_label = get_validate_label(PARAMS)
AUDIO_DIR = PARAMS.audio_dir if PARAMS.audio_dir else os.path.join(PARAMS.tsv_dir, 'clips')
ALPHABET = Alphabet(PARAMS.filter_alphabet) if PARAMS.filter_alphabet else None

View File

@ -19,7 +19,7 @@ import unicodedata
import librosa
import soundfile # <= Has an external dependency on libsndfile
from util.text import validate_label
from util.importers import validate_label_eng as validate_label
def _download_and_preprocess_data(data_dir):
# Assume data_dir contains extracted LDC2004S13, LDC2004T19, LDC2005S13, LDC2005T19

View File

@ -10,7 +10,7 @@ import csv
import math
import urllib
import logging
from util.importers import get_importers_parser
from util.importers import get_importers_parser, get_validate_label
import subprocess
from os import path
from pathlib import Path
@ -19,8 +19,6 @@ import swifter
import pandas as pd
from sox import Transformer
from util.text import validate_label
__version__ = "0.1.0"
_logger = logging.getLogger(__name__)
@ -290,6 +288,7 @@ def main(args):
args ([str]): command line parameter list
"""
args = parse_args(args)
validate_label = get_validate_label(args)
setup_logging(args.loglevel)
_logger.info("Starting GramVaani importer...")
_logger.info("Starting loading GramVaani csv...")

View File

@ -7,8 +7,9 @@ import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
from util.importers import get_importers_parser, get_validate_label
import argparse
import csv
import re
import sox
@ -26,7 +27,7 @@ from os import path
from glob import glob
from util.downloader import maybe_download
from util.text import Alphabet, validate_label
from util.text import Alphabet
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -185,6 +186,7 @@ def handle_args():
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
validate_label = get_validate_label(CLI_ARGS)
bogus_regexes = []
if CLI_ARGS.bogus_records:

View File

@ -9,7 +9,7 @@ import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
from util.importers import get_importers_parser, get_validate_label
import csv
import subprocess
@ -26,7 +26,7 @@ from os import path
from glob import glob
from util.downloader import maybe_download
from util.text import Alphabet, validate_label
from util.text import Alphabet
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -182,6 +182,7 @@ if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
SKIP_LIST = filter(None, CLI_ARGS.skiplist.split(','))
validate_label = get_validate_label(CLI_ARGS)
def label_filter(label):
if CLI_ARGS.normalize:

View File

@ -7,7 +7,7 @@ import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
from util.importers import get_importers_parser, get_validate_label
import csv
import re
@ -27,7 +27,7 @@ from os import path
from glob import glob
from util.downloader import maybe_download
from util.text import Alphabet, validate_label
from util.text import Alphabet
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -203,6 +203,7 @@ def handle_args():
if __name__ == "__main__":
CLI_ARGS = handle_args()
ALPHABET = Alphabet(CLI_ARGS.filter_alphabet) if CLI_ARGS.filter_alphabet else None
validate_label = get_validate_label(CLI_ARGS)
def label_filter(label):
if CLI_ARGS.normalize:

View File

@ -20,7 +20,7 @@ import wave
import codecs
import tarfile
import requests
from util.text import validate_label
from util.importers import validate_label_eng as validate_label
import librosa
import soundfile # <= Has an external dependency on libsndfile

View File

@ -27,7 +27,8 @@ from os import path
from glob import glob
from collections import Counter
from multiprocessing.pool import ThreadPool
from util.text import Alphabet, validate_label
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR
SWC_URL = "https://www2.informatik.uni-hamburg.de/nats/pub/SWC/SWC_{language}.tar"

View File

@ -8,7 +8,7 @@ import re
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from util.importers import get_importers_parser
from util.importers import get_importers_parser, get_validate_label
import csv
import unidecode
@ -25,7 +25,6 @@ from util.downloader import SIMPLE_BAR
from os import path
from util.downloader import maybe_download
from util.text import validate_label
from util.helpers import secs_to_hours
FIELDNAMES = ['wav_filename', 'wav_filesize', 'transcript']
@ -193,4 +192,5 @@ def handle_args():
if __name__ == "__main__":
cli_args = handle_args()
validate_label = get_validate_label(cli_args)
_download_and_preprocess_data(cli_args.target_dir, cli_args.english_compatible)

View File

@ -21,7 +21,8 @@ import xml.etree.cElementTree as ET
from os import path
from collections import Counter
from util.text import Alphabet, validate_label
from util.text import Alphabet
from util.importers import validate_label_eng as validate_label
from util.downloader import maybe_download, SIMPLE_BAR
TUDA_VERSION = 'v2'

View File

@ -1 +1,2 @@
absl-py
argparse

View File

@ -1,10 +1,38 @@
import argparse
import importlib
import os
import re
import sys
def get_importers_parser(description):
parser = argparse.ArgumentParser(description=description)
parser.add_argument('--validate_label_locale', help='Path to a Python file defining a |validate_label| function for your locale. WARNING: THIS WILL ADD THIS FILE\'s DIRECTORY INTO PYTHONPATH.')
return parser
def get_validate_label(args):
"""
Expects an argparse.Namespace argument to search for validate_label_locale parameter.
If found, this will modify Python's library search path and add the directory of the
file pointed by the validate_label_locale argument.
:param args: The importer's CLI argument object
:type args: argparse.Namespace
:return: The user-supplied validate_label function
:type: function
"""
if 'validate_label_locale' not in args or (args.validate_label_locale is None):
print('WARNING: No --validate_label_locale specified, your might end with inconsistent dataset.')
return validate_label_eng
if not os.path.exists(os.path.abspath(args.validate_label_locale)):
print('ERROR: Inexistent --validate_label_locale specified. Please check.')
return None
module_dir = os.path.abspath(os.path.dirname(args.validate_label_locale))
sys.path.insert(1, module_dir)
fname = os.path.basename(args.validate_label_locale).replace('.py', '')
locale_module = importlib.import_module(fname, package=None)
return locale_module.validate_label
# Validate and normalize transcriptions. Returns a cleaned version of the label
# or None if it's invalid.
def validate_label_eng(label):

View File

@ -0,0 +1,2 @@
def validate_label(label):
return label

View File

@ -1,6 +1,7 @@
import unittest
from .importers import validate_label_eng
from argparse import Namespace
from .importers import validate_label_eng, get_validate_label
class TestValidateLabelEng(unittest.TestCase):
@ -8,5 +9,30 @@ class TestValidateLabelEng(unittest.TestCase):
label = validate_label_eng("this is a 1 2 3 test")
self.assertEqual(label, None)
class TestGetValidateLabel(unittest.TestCase):
def test_no_validate_label_locale(self):
f = get_validate_label(Namespace())
self.assertEqual(f('toto'), 'toto')
self.assertEqual(f('toto1234'), None)
self.assertEqual(f('toto1234[{[{[]'), None)
def test_validate_label_locale_default(self):
f = get_validate_label(Namespace(validate_label_locale=None))
self.assertEqual(f('toto'), 'toto')
self.assertEqual(f('toto1234'), None)
self.assertEqual(f('toto1234[{[{[]'), None)
def test_get_validate_label_missing(self):
args = Namespace(validate_label_locale='util/test_data/validate_locale_ger.py')
f = get_validate_label(args)
self.assertEqual(f, None)
def test_get_validate_label(self):
args = Namespace(validate_label_locale='util/test_data/validate_locale_fra.py')
f = get_validate_label(args)
l = f('toto')
self.assertEqual(l, 'toto')
if __name__ == '__main__':
unittest.main()