# Mirror of https://github.com/mozilla/DeepSpeech.git (synced 2025-10-26 11:19:39 +00:00).
# Mirrored file: 118 lines, 7.3 KiB, Python.
from __future__ import absolute_import, division, print_function
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
# Module-level alias for the flag values object exposed by tf.app.flags, so
# importers of this module can refer to it simply as FLAGS.
FLAGS = tf.app.flags.FLAGS
|
|
|
|
|
|
def create_flags():
    """Register all flag definitions of this module through ``tf.app.flags``.

    Every flag is described by one row of a table holding the ``DEFINE_*``
    function to call, the flag name, its default value and its help text.
    The rows are registered one after another in the order they are listed.

    Takes no arguments and returns ``None``; its only effect is the sequence
    of ``DEFINE_*`` calls.
    """
    flags = tf.app.flags

    as_string = flags.DEFINE_string
    as_integer = flags.DEFINE_integer
    as_float = flags.DEFINE_float
    as_boolean = flags.DEFINE_boolean

    # Rows are (definer, flag name, default value, help text).
    flag_table = [
        # Importer
        # ========
        (as_string, 'train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.'),
        (as_string, 'dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.'),
        (as_string, 'test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.'),

        (as_string, 'feature_cache', '', 'path where cached features extracted from --train_files will be saved. If empty, caching will be done in memory and no files will be written.'),

        (as_integer, 'feature_win_len', 32, 'feature extraction audio window length in milliseconds'),
        (as_integer, 'feature_win_step', 20, 'feature extraction window step length in milliseconds'),
        (as_integer, 'audio_sample_rate', 16000, 'sample rate value expected by model'),

        # Global Constants
        # ================
        (as_integer, 'epochs', 75, 'how many epochs (complete runs through the train files) to train for'),

        (as_float, 'dropout_rate', 0.05, 'dropout rate for feedforward layers'),
        (as_float, 'dropout_rate2', -1.0, 'dropout rate for layer 2 - defaults to dropout_rate'),
        (as_float, 'dropout_rate3', -1.0, 'dropout rate for layer 3 - defaults to dropout_rate'),
        (as_float, 'dropout_rate4', 0.0, 'dropout rate for layer 4 - defaults to 0.0'),
        (as_float, 'dropout_rate5', 0.0, 'dropout rate for layer 5 - defaults to 0.0'),
        (as_float, 'dropout_rate6', -1.0, 'dropout rate for layer 6 - defaults to dropout_rate'),

        (as_float, 'relu_clip', 20.0, 'ReLU clipping value for non-recurrent layers'),

        # Adam optimizer(http://arxiv.org/abs/1412.6980) parameters
        (as_float, 'beta1', 0.9, 'beta 1 parameter of Adam optimizer'),
        (as_float, 'beta2', 0.999, 'beta 2 parameter of Adam optimizer'),
        (as_float, 'epsilon', 1e-8, 'epsilon parameter of Adam optimizer'),
        (as_float, 'learning_rate', 0.001, 'learning rate of Adam optimizer'),

        # Batch sizes
        (as_integer, 'train_batch_size', 1, 'number of elements in a training batch'),
        (as_integer, 'dev_batch_size', 1, 'number of elements in a validation batch'),
        (as_integer, 'test_batch_size', 1, 'number of elements in a test batch'),

        (as_integer, 'export_batch_size', 1, 'number of elements per batch on the exported graph'),

        # Performance(UNSUPPORTED)
        (as_integer, 'inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details'),
        (as_integer, 'intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details'),

        # Sample limits
        (as_integer, 'limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit'),
        (as_integer, 'limit_dev', 0, 'maximum number of elements to use from validation set- 0 means no limit'),
        (as_integer, 'limit_test', 0, 'maximum number of elements to use from test set- 0 means no limit'),

        # Checkpointing
        (as_string, 'checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification'),
        (as_integer, 'checkpoint_secs', 600, 'checkpoint saving interval in seconds'),
        (as_integer, 'max_to_keep', 5, 'number of checkpoint files to keep - default value is 5'),
        (as_string, 'load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init'),

        # Exporting
        (as_string, 'export_dir', '', 'directory in which exported models are stored - if omitted, the model won\'t get exported'),
        (as_integer, 'export_version', 1, 'version number of the exported model'),
        (as_boolean, 'remove_export', False, 'whether to remove old exported models'),
        (as_boolean, 'export_tflite', False, 'export a graph ready for TF Lite engine'),
        (as_boolean, 'use_seq_length', True, 'have sequence_length in the exported graph(will make tfcompile unhappy)'),
        (as_integer, 'n_steps', 16, 'how many timesteps to process at once by the export graph, higher values mean more latency'),
        (as_string, 'export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.'),

        # Reporting
        (as_integer, 'log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL'),
        (as_boolean, 'show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.'),

        (as_boolean, 'log_placement', False, 'whether to log device placement of the operators to the console'),
        (as_integer, 'report_count', 10, 'number of phrases with lowest WER(best matching) to print out during a WER report'),

        (as_string, 'summary_dir', '', 'target directory for TensorBoard summaries - defaults to directory "deepspeech/summaries" within user\'s data home specified by the XDG Base Directory Specification'),

        # Geometry
        (as_integer, 'n_hidden', 2048, 'layer width to use when initialising layers'),

        # Initialization
        (as_integer, 'random_seed', 4568, 'default random seed that is used to initialize variables'),

        # Early Stopping
        (as_boolean, 'early_stop', True, 'enable early stopping mechanism over validation dataset. If validation is not being run, early stopping is disabled.'),
        (as_integer, 'es_steps', 4, 'number of validations to consider for early stopping. Loss is not stored in the checkpoint so when checkpoint is revived it starts the loss calculation from start at that point'),
        (as_float, 'es_mean_th', 0.5, 'mean threshold for loss to determine the condition if early stopping is required'),
        (as_float, 'es_std_th', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required'),

        # Decoder
        (as_string, 'alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.'),
        (as_string, 'lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM'),
        (as_string, 'lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie'),
        (as_integer, 'beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions'),
        (as_float, 'lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.'),
        (as_float, 'lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.'),

        # Inference mode
        (as_string, 'one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.'),
    ]

    # Register in table order so the definition sequence matches the listing.
    for define, flag_name, default_value, help_text in flag_table:
        define(flag_name, default_value, help_text)
|