Changes addressing PR comments

This commit is contained in:
Tilman Kamp 2016-11-28 16:54:07 +01:00
parent 313a71e9da
commit 2e26e19dc3

View File

@ -98,10 +98,6 @@
"ds_importer = os.environ.get('ds_importer', 'ted')\n",
"ds_dataset_path = os.environ.get('ds_dataset_path', os.path.join('./data', ds_importer))\n",
"\n",
"ds_limit_train = int(os.environ.get('ds_limit_train', 0))\n",
"ds_limit_dev = int(os.environ.get('ds_limit_dev', 0))\n",
"ds_limit_test = int(os.environ.get('ds_limit_test', 0))\n",
"\n",
"import importlib\n",
"ds_importer_module = importlib.import_module('util.importers.%s' % ds_importer)\n",
"\n",
@ -132,6 +128,9 @@
"* `test_batch_size` - The number of elements in a test batch\n",
"* `display_step` - The number of epochs we cycle through before displaying progress\n",
"* `checkpoint_step` - The number of epochs we cycle through before checkpointing the model\n",
"* `limit_train` - The maximum amount of samples taken from (the beginning of) the train set - 0 meaning no limit\n",
"* `limit_dev` - The maximum amount of samples taken from (the beginning of) the validation set - 0 meaning no limit\n",
"* `limit_test` - The maximum amount of samples taken from (the beginning of) the test set - 0 meaning no limit\n",
"* `checkpoint_dir` - The directory in which checkpoints are stored\n",
"* `restore_checkpoint` - Whether to resume from checkpoints when training\n",
"* `export_dir` - The directory in which exported models are stored\n",
@ -160,6 +159,9 @@
"display_step = int(os.environ.get('ds_display_step', 1)) # TODO: Determine a reasonable value for this\n",
"validation_step = int(os.environ.get('ds_validation_step', 1)) # TODO: Determine a reasonable value for this\n",
"checkpoint_step = int(os.environ.get('ds_checkpoint_step', 5)) # TODO: Determine a reasonable value for this\n",
"limit_train = int(os.environ.get('ds_limit_train', 0))\n",
"limit_dev = int(os.environ.get('ds_limit_dev', 0))\n",
"limit_test = int(os.environ.get('ds_limit_test', 0))\n",
"checkpoint_dir = os.environ.get('ds_checkpoint_dir', xdg.save_data_path('deepspeech'))\n",
"restore_checkpoint = bool(int(os.environ.get('ds_restore_checkpoint', 0)))\n",
"export_dir = os.environ.get('ds_export_dir', None)\n",
@ -1324,9 +1326,9 @@
" test_batch_size, \\\n",
" n_input, \\\n",
" n_context, \\\n",
" limit_dev=ds_limit_dev, \\\n",
" limit_test=ds_limit_test, \\\n",
" limit_train=ds_limit_train)\n",
" limit_dev=limit_dev, \\\n",
" limit_test=limit_test, \\\n",
" limit_train=limit_train)\n",
"\n",
"def read_data_set(set_name):\n",
" # Obtain all the data sets\n",