mirror of
https://github.com/mozilla/DeepSpeech.git
synced 2025-10-26 11:19:39 +00:00
Don't duplicate graph to do validation
This commit is contained in:
parent
757fc74e1e
commit
77100ed1df
@ -91,7 +91,7 @@
|
||||
"from util.log import merge_logs\n",
|
||||
"from util.gpu import get_available_gpus\n",
|
||||
"from util.shared_lib import check_cupti\n",
|
||||
"from util.text import sparse_tensor_value_to_texts, wers\n",
|
||||
"from util.text import ctc_label_dense_to_sparse, sparse_tensor_value_to_texts, wers\n",
|
||||
"from tensorflow.python.ops import ctc_ops\n",
|
||||
"from tensorflow.contrib.session_bundle import exporter\n",
|
||||
"\n",
|
||||
@ -169,7 +169,9 @@
|
||||
"source": [
|
||||
"Note that we use the Adam optimizer[[3]](http://arxiv.org/abs/1412.6980) instead of Nesterov’s Accelerated Gradient [[4]](http://www.cs.utoronto.ca/~ilya/pubs/2013/1051_2.pdf) used in the original DeepSpeech paper, as, at the time of writing, TensorFlow does not have an implementation of Nesterov’s Accelerated Gradient [[4]](http://www.cs.utoronto.ca/~ilya/pubs/2013/1051_2.pdf).\n",
|
||||
"\n",
|
||||
"As we will also employ dropout on the feedforward layers of the network, we need to define a parameter `dropout_rate` that keeps track of the dropout rate for these layers"
|
||||
"As we will also employ dropout on the feedforward layers of the network, we need to define a parameter `dropout_rate` that keeps track of the dropout rate for these layers.\n",
|
||||
"\n",
|
||||
"To avoid graph duplication when performing validation steps, we use a placeholder and an alternative feed_dict during validation to pull data from the corresponding queue."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -185,11 +187,20 @@
|
||||
"# This global placeholder will be used for all dropout definitions\n",
|
||||
"dropout_rate_placeholder = tf.placeholder(tf.float32)\n",
|
||||
"\n",
|
||||
"# The feed_dict used for training employs the given dropout_rate\n",
|
||||
"feed_dict_train = { dropout_rate_placeholder: dropout_rate }\n",
|
||||
"# This placeholder will be used to select between queues\n",
|
||||
"queue_selector_placeholder = tf.placeholder(tf.uint8)\n",
|
||||
"\n",
|
||||
"# While the feed_dict used for validation, test and train progress reporting employs zero dropout\n",
|
||||
"feed_dict = { dropout_rate_placeholder: 0.0 }"
|
||||
"# The feed_dict used for training employs the given dropout_rate\n",
|
||||
"feed_dict_train = { dropout_rate_placeholder: dropout_rate,\n",
|
||||
" queue_selector_placeholder: 0 }\n",
|
||||
"\n",
|
||||
"# The feed dict used for validation employs zero dropout and selects the validation queue\n",
|
||||
"feed_dict_validate = { dropout_rate_placeholder: 0.0,\n",
|
||||
" queue_selector_placeholder: 1 }\n",
|
||||
"\n",
|
||||
"# While the feed_dict used for test reporting employs zero dropout\n",
|
||||
"feed_dict_test = { dropout_rate_placeholder: 0.0,\n",
|
||||
" queue_selector_placeholder: 0 }"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -644,10 +655,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def calculate_accuracy_and_loss(batch_set):\n",
|
||||
" # Obtain the next batch of data\n",
|
||||
" batch_x, batch_seq_len, batch_y = batch_set.next_batch()\n",
|
||||
"\n",
|
||||
"def calculate_accuracy_and_loss(batch_x, batch_seq_len, batch_y):\n",
|
||||
" # Calculate the logits of the batch using BiRNN\n",
|
||||
" logits = BiRNN(batch_x, tf.to_int64(batch_seq_len))\n",
|
||||
" \n",
|
||||
@ -677,15 +685,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The first lines of `calculate_accuracy_and_loss()`\n",
|
||||
"```python\n",
|
||||
"def calculate_accuracy_and_loss(batch_set):\n",
|
||||
" # Obtain the next batch of data\n",
|
||||
" batch_x, batch_seq_len, batch_y = batch_set.next_batch()\n",
|
||||
"```\n",
|
||||
"simply obtain the next mini-batch of data.\n",
|
||||
"\n",
|
||||
"The next line\n",
|
||||
"The first line of `calculate_accuracy_and_loss()`\n",
|
||||
"```python\n",
|
||||
" # Calculate the logits from the BiRNN\n",
|
||||
" logits = BiRNN(batch_x, batch_seq_len)\n",
|
||||
@ -912,7 +912,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_tower_results(batch_set, optimizer=None):\n",
|
||||
"def get_tower_results(batch_sets, optimizer=None):\n",
|
||||
" # Tower decodings to return\n",
|
||||
" tower_decodings = []\n",
|
||||
" # Tower labels to return\n",
|
||||
@ -932,9 +932,18 @@
|
||||
" with tf.device(available_devices[i]):\n",
|
||||
" # Create a scope for all operations of tower i\n",
|
||||
" with tf.name_scope('tower_%d' % i) as scope:\n",
|
||||
" # Fetch the next batch of data\n",
|
||||
" batch_x, batch_x_seq_len, batch_y, batch_y_seq_len = \\\n",
|
||||
" tf.cond(tf.less(queue_selector_placeholder, 1),\n",
|
||||
" lambda: batch_sets[0].next_batch(),\n",
|
||||
" lambda: batch_sets[1].next_batch())\n",
|
||||
" \n",
|
||||
" # Calculate the avg_loss and accuracy and retrieve the decoded \n",
|
||||
" # batch along with the original batch's labels (Y) of this tower\n",
|
||||
" total_loss, avg_loss, accuracy, decoded, labels = calculate_accuracy_and_loss(batch_set)\n",
|
||||
" batch_y = ctc_label_dense_to_sparse(batch_y, batch_y_seq_len, batch_size)\n",
|
||||
" total_loss, avg_loss, accuracy, decoded, labels = calculate_accuracy_and_loss(batch_x,\n",
|
||||
" batch_x_seq_len,\n",
|
||||
" batch_y)\n",
|
||||
" \n",
|
||||
" # Allow for variables to be re-used by the next tower\n",
|
||||
" tf.get_variable_scope().reuse_variables()\n",
|
||||
@ -1155,7 +1164,7 @@
|
||||
"source": [
|
||||
"def get_results_params(data_set):\n",
|
||||
" # Get tower results\n",
|
||||
" tower_decodings, tower_labels, _, tower_total_losses, _, _ = get_tower_results(data_set)\n",
|
||||
" tower_decodings, tower_labels, _, tower_total_losses, _, _ = get_tower_results([data_set, data_set])\n",
|
||||
" # Join the individual results tensors into a results_params tuple\n",
|
||||
" return (tower_labels, tower_decodings, tower_total_losses)\n",
|
||||
" "
|
||||
@ -1268,7 +1277,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def run_inference(session, caption, data_set, results_params=None):\n",
|
||||
"def run_inference(session, caption, data_set, feed_dict, results_params=None):\n",
|
||||
" if results_params is None:\n",
|
||||
" # Get the data_set specific graph end-points\n",
|
||||
" results_params = get_results_params(data_set)\n",
|
||||
@ -1322,10 +1331,7 @@
|
||||
" tower_total_losses, \\\n",
|
||||
" tower_avg_losses, \\\n",
|
||||
" avg_accuracy \\\n",
|
||||
" = get_tower_results(data_sets.train, optimizer)\n",
|
||||
" \n",
|
||||
" # Validation results parameters\n",
|
||||
" dev_results_params = get_results_params(data_sets.dev)\n",
|
||||
" = get_tower_results([data_sets.train, data_sets.dev], optimizer)\n",
|
||||
" \n",
|
||||
" # Average tower gradients\n",
|
||||
" avg_tower_gradients = average_gradients(tower_gradients)\n",
|
||||
@ -1348,8 +1354,8 @@
|
||||
" # Start importer's queue threads\n",
|
||||
" data_sets.start_queue_threads(session)\n",
|
||||
" \n",
|
||||
" # Training results parameters\n",
|
||||
" train_results_params = (tower_labels, tower_decodings, tower_total_losses)\n",
|
||||
" # Result parameters\n",
|
||||
" results_params = (tower_labels, tower_decodings, tower_total_losses)\n",
|
||||
" \n",
|
||||
" # Prepare tensor board logging\n",
|
||||
" merged = tf.merge_all_summaries()\n",
|
||||
@ -1382,7 +1388,7 @@
|
||||
" # Create training results tuple\n",
|
||||
" train_results = ([],[],[])\n",
|
||||
" # Extend the session.run parameters\n",
|
||||
" params.append(train_results_params)\n",
|
||||
" params.append(results_params)\n",
|
||||
"\n",
|
||||
" # Loop over the batches\n",
|
||||
" for batch in range(int(ceil(batches_per_device))):\n",
|
||||
@ -1417,8 +1423,7 @@
|
||||
" \n",
|
||||
" # Validation step\n",
|
||||
" if epoch % validation_step == 0:\n",
|
||||
" dev_wer = run_inference(session, \"Validation\", data_sets.dev, results_params=dev_results_params)\n",
|
||||
" \n",
|
||||
" dev_wer = run_inference(session, \"Validation\", data_sets.dev, feed_dict_validate, results_params)\n",
|
||||
"\n",
|
||||
" # Checkpoint the model\n",
|
||||
" if (epoch % checkpoint_step == 0) or (epoch == training_iters - 1):\n",
|
||||
@ -1469,7 +1474,7 @@
|
||||
" duration = duration.days * 86400 + duration.seconds\n",
|
||||
" \n",
|
||||
" # Finally the model is tested against some unbiased data-set\n",
|
||||
" test_wer = run_inference(session, \"Test\", data_sets.test)"
|
||||
" test_wer = run_inference(session, \"Test\", data_sets.test, feed_dict_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -82,8 +82,7 @@ class DataSet(object):
|
||||
|
||||
def next_batch(self):
|
||||
source, source_lengths, target, target_lengths = self._example_queue.dequeue_many(self._batch_size)
|
||||
sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._batch_size)
|
||||
return source, source_lengths, sparse_labels
|
||||
return source, source_lengths, target, target_lengths
|
||||
|
||||
@property
|
||||
def total_batches(self):
|
||||
|
||||
@ -98,8 +98,7 @@ class DataSet(object):
|
||||
|
||||
def next_batch(self):
|
||||
source, source_lengths, target, target_lengths = self._example_queue.dequeue_many(self._batch_size)
|
||||
sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._batch_size)
|
||||
return source, source_lengths, sparse_labels
|
||||
return source, source_lengths, target, target_lengths
|
||||
|
||||
@property
|
||||
def total_batches(self):
|
||||
|
||||
@ -106,8 +106,7 @@ class DataSet(object):
|
||||
|
||||
def next_batch(self):
|
||||
source, source_lengths, target, target_lengths = self._example_queue.dequeue_many(self._batch_size)
|
||||
sparse_labels = ctc_label_dense_to_sparse(target, target_lengths, self._batch_size)
|
||||
return source, source_lengths, sparse_labels
|
||||
return source, source_lengths, target, target_lengths
|
||||
|
||||
@property
|
||||
def total_batches(self):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user