Make sure the initializer passed to tf.scan doesn't break the API contract

We need to make sure the initializer shape matches the return value
of the callable passed to tf.scan.

This also adds an assertion on the shape of labels and the values
in label_lengths that enforces a condition that is needed for
ctc_label_dense_to_sparse to work.
This commit is contained in:
Reuben Morais 2016-11-08 12:19:42 -02:00
parent dba8f219f7
commit d989e8de09

View File

@ -132,6 +132,11 @@ def gather_nd(params, indices, shape):
# queue and convert to a sparse representation after dequeuing a batch.
#
def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
# The second dimension of labels must be equal to the longest label length in the batch
correct_shape_assert = tf.assert_equal(tf.shape(labels)[1], tf.reduce_max(label_lengths))
with tf.control_dependencies([correct_shape_assert]):
labels = tf.identity(labels)
label_shape = tf.shape(labels)
num_batches_tns = tf.pack([label_shape[0]])
max_num_labels_tns = tf.pack([label_shape[1]])
@ -139,6 +144,7 @@ def ctc_label_dense_to_sparse(labels, label_lengths, batch_size):
return tf.expand_dims(tf.range(label_shape[1]), 0) < current_input
init = tf.cast(tf.fill(max_num_labels_tns, 0), tf.bool)
init = tf.expand_dims(init, 0)
dense_mask = tf.scan(range_less_than, label_lengths, initializer=init, parallel_iterations=1)
dense_mask = dense_mask[:, 0, :]