mirror of
https://github.com/mozilla/DeepSpeech.git
synced 2025-10-26 11:19:39 +00:00
Fix #15; logging top ten lowest loss samples
This commit is contained in:
parent
47bb4babde
commit
a7d009fef8
100
DeepSpeech.ipynb
100
DeepSpeech.ipynb
@ -602,7 +602,7 @@
|
||||
"source": [
|
||||
"In accord with [Deep Speech: Scaling up end-to-end speech recognition](http://arxiv.org/abs/1412.5567), the loss function used by our network should be the CTC loss function[[2]](http://www.cs.toronto.edu/~graves/preprint.pdf). Conveniently, this loss function is implemented in TensorFlow. Thus, we can simply make use of this implementation to define our loss.\n",
|
||||
"\n",
|
||||
"To do so we introduce a utility function `calculate_accuracy_and_loss()` that beam search decodes a mini-batch and calculates the average loss and accuracy. Next to loss and accuracy it returns the decoded result and the batch's original Y."
|
||||
"To do so we introduce a utility function `calculate_accuracy_and_loss()` that beam search decodes a mini-batch and calculates the loss and accuracy. Next to total and average loss it returns the accuracy, the decoded result and the batch's original Y."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -645,7 +645,7 @@
|
||||
" accuracy = tf.reduce_mean(distance)\n",
|
||||
"\n",
|
||||
" # Return results to the caller\n",
|
||||
" return avg_loss, accuracy, decoded, batch_y"
|
||||
" return total_loss, avg_loss, accuracy, decoded, batch_y"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -699,10 +699,10 @@
|
||||
" accuracy = tf.reduce_mean(distance)\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Finally, the `avg_loss`, accuracy, the decoded batch and the original batch's Y are returned to the caller\n",
|
||||
"Finally, the `total_loss`, `avg_loss`, `accuracy`, the `decoded` batch and the original `batch_y` are returned to the caller\n",
|
||||
"```python\n",
|
||||
" # Return results to the caller\n",
|
||||
" return avg_loss, accuracy, decoded, batch_y\n",
|
||||
" return total_loss, avg_loss, accuracy, decoded, batch_y\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
@ -892,6 +892,12 @@
|
||||
" tower_labels = []\n",
|
||||
" # Tower gradients to return\n",
|
||||
" tower_gradients = []\n",
|
||||
" # Tower total batch losses to return\n",
|
||||
" tower_total_losses = []\n",
|
||||
" # Tower avg batch losses to return\n",
|
||||
" tower_avg_losses = []\n",
|
||||
" # Tower accuracies to return\n",
|
||||
" tower_accuracies = []\n",
|
||||
" \n",
|
||||
" # Loop over available_devices\n",
|
||||
" for i in xrange(len(available_devices)):\n",
|
||||
@ -901,7 +907,7 @@
|
||||
" with tf.name_scope('tower_%d' % i) as scope:\n",
|
||||
" # Calculate the avg_loss and accuracy and retrieve the decoded \n",
|
||||
" # batch along with the original batch's labels (Y) of this tower\n",
|
||||
" avg_loss, accuracy, decoded, labels = calculate_accuracy_and_loss(batch_set)\n",
|
||||
" total_loss, avg_loss, accuracy, decoded, labels = calculate_accuracy_and_loss(batch_set)\n",
|
||||
" \n",
|
||||
" # Allow for variables to be re-used by the next tower\n",
|
||||
" tf.get_variable_scope().reuse_variables()\n",
|
||||
@ -920,9 +926,18 @@
|
||||
"\n",
|
||||
" # Retain tower's gradients\n",
|
||||
" tower_gradients.append(gradients)\n",
|
||||
" \n",
|
||||
" # Retain tower's total losses\n",
|
||||
" tower_total_losses.append(total_loss)\n",
|
||||
" \n",
|
||||
" # Retain tower's avg losses\n",
|
||||
" tower_avg_losses.append(avg_loss)\n",
|
||||
" \n",
|
||||
" # Retain tower's accuracies\n",
|
||||
" tower_accuracies.append(accuracy)\n",
|
||||
"\n",
|
||||
" # Return results to caller\n",
|
||||
" return tower_decodings, tower_labels, tower_gradients, avg_loss, accuracy"
|
||||
" return tower_decodings, tower_labels, tower_gradients, tower_total_losses, tower_avg_losses, tower_accuracies"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1110,8 +1125,8 @@
|
||||
"source": [
|
||||
"def decode_batch(data_set):\n",
|
||||
" # Get gradients for each tower (Runs across all GPU's)\n",
|
||||
" tower_decodings, tower_labels, _, _, _ = get_tower_results(data_set)\n",
|
||||
" return tower_decodings, tower_labels\n",
|
||||
" tower_decodings, tower_labels, _, tower_total_losses, _, _ = get_tower_results(data_set)\n",
|
||||
" return tower_decodings, tower_labels, tower_total_losses\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
@ -1130,22 +1145,24 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def calculate_wer(session, tower_decodings, tower_labels):\n",
|
||||
"def calculate_wer(session, tower_decodings, tower_labels, tower_total_losses):\n",
|
||||
" originals = []\n",
|
||||
" results = []\n",
|
||||
" losses = []\n",
|
||||
" \n",
|
||||
" # Normalization\n",
|
||||
" tower_decodings = [j for i in tower_decodings for j in i]\n",
|
||||
" \n",
|
||||
" # Iterating over the towers\n",
|
||||
" for i in range(len(tower_decodings)):\n",
|
||||
" decoded, labels = session.run([tower_decodings[i], tower_labels[i]], feed_dict)\n",
|
||||
" decoded, labels, loss = session.run([tower_decodings[i], tower_labels[i], tower_total_losses[i]], feed_dict)\n",
|
||||
" originals.extend(sparse_tensor_value_to_texts(labels))\n",
|
||||
" results.extend(sparse_tensor_value_to_texts(decoded))\n",
|
||||
" losses.extend(loss)\n",
|
||||
" \n",
|
||||
" # Pairwise calculation of all rates\n",
|
||||
" rates, mean = wers(originals, results)\n",
|
||||
" return zip(originals, results, rates), mean"
|
||||
" return zip(originals, results, rates, losses), mean"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1163,13 +1180,20 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def print_wer_report(session, caption, tower_decodings, tower_labels, show_example=True):\n",
|
||||
" items, mean = calculate_wer(session, tower_decodings, tower_labels)\n",
|
||||
"def print_wer_report(session, caption, tower_decodings, tower_labels, tower_total_losses, show_ranked=True):\n",
|
||||
" items, mean = calculate_wer(session, tower_decodings, tower_labels, tower_total_losses)\n",
|
||||
" print \"%s WER: %f09\" % (caption, mean)\n",
|
||||
" if len(items) > 0 and show_example:\n",
|
||||
" print \"Example (WER = %f09)\" % items[-1][2]\n",
|
||||
" print \" - source: \\\"%s\\\"\" % items[-1][0]\n",
|
||||
" print \" - result: \\\"%s\\\"\" % items[-1][1] \n",
|
||||
" if len(items) > 0 and show_ranked:\n",
|
||||
" items = [a for a in items if a[2] > 0]\n",
|
||||
" items.sort(key=lambda a: a[3])\n",
|
||||
" items = items[:10]\n",
|
||||
" items.sort(key=lambda a: a[2])\n",
|
||||
" for a in items:\n",
|
||||
" print\n",
|
||||
" print \"WER: %f09\" % a[2]\n",
|
||||
" print \" - source: \\\"%s\\\"\" % a[0]\n",
|
||||
" print \" - result: \\\"%s\\\"\" % a[1] \n",
|
||||
" print \" - loss: \\\"%s\\\"\" % a[3]\n",
|
||||
" return items, mean"
|
||||
]
|
||||
},
|
||||
@ -1203,11 +1227,19 @@
|
||||
" optimizer = create_optimizer()\n",
|
||||
"\n",
|
||||
" # Get gradients for each tower (Runs across all GPU's)\n",
|
||||
" tower_decodings, tower_labels, tower_gradients, tower_loss, accuracy = \\\n",
|
||||
" get_tower_results(data_sets.train, optimizer)\n",
|
||||
" tower_decodings, \\\n",
|
||||
" tower_labels, \\\n",
|
||||
" tower_gradients, \\\n",
|
||||
" tower_total_losses, \\\n",
|
||||
" tower_avg_losses, \\\n",
|
||||
" tower_accuracies \\\n",
|
||||
" = get_tower_results(data_sets.train, optimizer)\n",
|
||||
" \n",
|
||||
" # Validation step preparation\n",
|
||||
" validation_tower_decodings, validation_tower_labels = decode_batch(data_sets.dev)\n",
|
||||
" validation_tower_decodings, \\\n",
|
||||
" validation_tower_labels, \\\n",
|
||||
" validation_tower_total_losses \\\n",
|
||||
" = decode_batch(data_sets.dev)\n",
|
||||
"\n",
|
||||
" # Average tower gradients\n",
|
||||
" avg_tower_gradients = average_gradients(tower_gradients)\n",
|
||||
@ -1239,16 +1271,21 @@
|
||||
" \n",
|
||||
" # Validation step\n",
|
||||
" if epoch % validation_step == 0:\n",
|
||||
" _, last_validation_wer = print_wer_report(session, \"Validation\", validation_tower_decodings, validation_tower_labels)\n",
|
||||
" _, last_validation_wer = print_wer_report( \\\n",
|
||||
" session, \\\n",
|
||||
" \"Validation\", \\\n",
|
||||
" validation_tower_decodings, \\\n",
|
||||
" validation_tower_labels, \\\n",
|
||||
" validation_tower_total_losses)\n",
|
||||
" print\n",
|
||||
"\n",
|
||||
" # Loop over the batches\n",
|
||||
" for batch in range(int(ceil(float(total_batches)/len(available_devices)))):\n",
|
||||
" for batch in range(int(ceil(float(total_batches) / len(available_devices)))):\n",
|
||||
" # Compute the average loss for the last batch\n",
|
||||
" _, batch_avg_loss = session.run([apply_gradient_op, tower_loss], feed_dict_train)\n",
|
||||
" session.run(apply_gradient_op, feed_dict_train)\n",
|
||||
"\n",
|
||||
" # Add batch to total_accuracy\n",
|
||||
" total_accuracy += session.run(accuracy, feed_dict_train)\n",
|
||||
" total_accuracy += session.run(tower_accuracies[-1], feed_dict_train)\n",
|
||||
"\n",
|
||||
" # Log all variable states in current step\n",
|
||||
" step = epoch * total_batches + batch * len(available_devices)\n",
|
||||
@ -1259,7 +1296,12 @@
|
||||
" # Print progress message\n",
|
||||
" if epoch % display_step == 0:\n",
|
||||
" print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batches))\n",
|
||||
" _, last_train_wer = print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
|
||||
" _, last_train_wer = print_wer_report( \\\n",
|
||||
" session, \\\n",
|
||||
" \"Training\", \\\n",
|
||||
" tower_decodings, \\\n",
|
||||
" tower_labels, \\\n",
|
||||
" tower_total_losses)\n",
|
||||
" print\n",
|
||||
"\n",
|
||||
" # Checkpoint the model\n",
|
||||
@ -1329,8 +1371,8 @@
|
||||
"# Define CPU as device on which the muti-gpu testing is orchestrated\n",
|
||||
"with tf.device('/cpu:0'):\n",
|
||||
" # Test network\n",
|
||||
" test_decodings, test_labels = decode_batch(data_sets.test)\n",
|
||||
" _, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels)"
|
||||
" test_decodings, test_labels, test_total_losses = decode_batch(data_sets.test)\n",
|
||||
" _, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels, test_total_losses)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1386,7 +1428,7 @@
|
||||
" 'total_batches_validation': data_sets.dev.total_batches, \\\n",
|
||||
" 'total_batches_test': data_sets.test.total_batches, \\\n",
|
||||
" 'data_set': { \\\n",
|
||||
" 'name': ds_importer\n",
|
||||
" 'name': ds_importer \\\n",
|
||||
" }, \\\n",
|
||||
" }, \\\n",
|
||||
" 'results': { \\\n",
|
||||
@ -1434,7 +1476,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.11"
|
||||
"version": "2.7.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user