mirror of
https://github.com/mozilla/DeepSpeech.git
synced 2025-10-26 11:19:39 +00:00
Make sure the initializer passed to tf.scan doesn't break the API contract
We need to make sure the initializer shape matches the return value of the callable passed to tf.scan. This also adds an assertion on the shape of labels and the values in label_lengths that enforces a condition that is needed for ctc_label_dense_to_sparse to work.
This commit is contained in:
parent
dba8f219f7
commit
d989e8de09
@ -132,6 +132,11 @@ def gather_nd(params, indices, shape):
|
||||
# queue and convert to a sparse representation after dequeuing a batch.
|
||||
#
|
||||
def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
|
||||
# The second dimension of labels must be equal to the longest label length in the batch
|
||||
correct_shape_assert = tf.assert_equal(tf.shape(labels)[1], tf.reduce_max(label_lengths))
|
||||
with tf.control_dependencies([correct_shape_assert]):
|
||||
labels = tf.identity(labels)
|
||||
|
||||
label_shape = tf.shape(labels)
|
||||
num_batches_tns = tf.pack([label_shape[0]])
|
||||
max_num_labels_tns = tf.pack([label_shape[1]])
|
||||
@ -139,6 +144,7 @@ def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
|
||||
return tf.expand_dims(tf.range(label_shape[1]), 0) < current_input
|
||||
|
||||
init = tf.cast(tf.fill(max_num_labels_tns, 0), tf.bool)
|
||||
init = tf.expand_dims(init, 0)
|
||||
dense_mask = tf.scan(range_less_than, label_lengths, initializer=init, parallel_iterations=1)
|
||||
dense_mask = dense_mask[:, 0, :]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user