#!/usr/bin/env python # -*- coding: utf-8 -*- from __future__ import absolute_import, division, print_function import itertools import json from multiprocessing import cpu_count import absl.app import numpy as np import progressbar import tensorflow as tf import tensorflow.compat.v1 as tfv1 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer from six.moves import zip from util.config import Config, initialize_globals from util.evaluate_tools import calculate_report from util.feeding import create_dataset from util.flags import create_flags, FLAGS from util.logging import log_error, log_progress, create_progressbar def sparse_tensor_value_to_texts(value, alphabet): r""" Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings representing its values, converting tokens to strings using ``alphabet``. """ return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet) def sparse_tuple_to_texts(sp_tuple, alphabet): indices = sp_tuple[0] values = sp_tuple[1] results = [''] * sp_tuple[2][0] for i, index in enumerate(indices): results[index[0]] += alphabet.string_from_label(values[i]) # List of strings return results def evaluate(test_csvs, create_model, try_loading): if FLAGS.lm_binary_path: scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.lm_binary_path, FLAGS.lm_trie_path, Config.alphabet) else: scorer = None test_csvs = FLAGS.test_files.split(',') test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs] iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]), tfv1.data.get_output_shapes(test_sets[0]), output_classes=tfv1.data.get_output_classes(test_sets[0])) test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets] batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next() # One rate per layer no_dropout = [None] * 6 logits, _ = create_model(batch_x=batch_x, batch_size=FLAGS.test_batch_size, seq_length=batch_x_len, dropout=no_dropout) # Transpose to batch major and apply softmax for decoder transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2])) loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_x_len) tfv1.train.get_or_create_global_step() # Get number of accessible CPU cores for this process try: num_processes = cpu_count() except NotImplementedError: num_processes = 1 # Create a saver using variables from the above newly created graph saver = tfv1.train.Saver() with tfv1.Session(config=Config.session_config) as session: # Restore variables from training checkpoint loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation') if not loaded: loaded = try_loading(session, saver, 'checkpoint', 'most recent') if not loaded: log_error('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(FLAGS.checkpoint_dir)) exit(1) def run_test(init_op, dataset): wav_filenames = [] losses = [] predictions = [] ground_truths = [] bar = create_progressbar(prefix='Test epoch | ', widgets=['Steps: ', progressbar.Counter(), ' | ', progressbar.Timer()]).start() log_progress('Test epoch...') step_count = 0 # Initialize iterator to the appropriate dataset session.run(init_op) # First pass, compute losses and transposed logits for decoding while True: try: batch_wav_filenames, batch_logits, batch_loss, batch_lengths, batch_transcripts = \ session.run([batch_wav_filename, transposed, loss, batch_x_len, batch_y]) except tf.errors.OutOfRangeError: break decoded = ctc_beam_search_decoder_batch(batch_logits, batch_lengths, Config.alphabet, FLAGS.beam_width, num_processes=num_processes, scorer=scorer) predictions.extend(d[0][1] for d in decoded) ground_truths.extend(sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet)) wav_filenames.extend(wav_filename.decode('UTF-8') for wav_filename in batch_wav_filenames) losses.extend(batch_loss) step_count += 1 bar.update(step_count) bar.finish() wer, cer, samples = calculate_report(wav_filenames, ground_truths, predictions, losses) mean_loss = np.mean(losses) # Take only the first report_count items report_samples = itertools.islice(samples, FLAGS.report_count) print('Test on %s - WER: %f, CER: %f, loss: %f' % (dataset, wer, cer, mean_loss)) print('-' * 80) for sample in report_samples: print('WER: %f, CER: %f, loss: %f' % (sample.wer, sample.cer, sample.loss)) print(' - wav: file://%s' % sample.wav_filename) print(' - src: "%s"' % sample.src) print(' - res: "%s"' % sample.res) print('-' * 80) return samples samples = [] for csv, init_op in zip(test_csvs, test_init_ops): print('Testing model on {}'.format(csv)) samples.extend(run_test(init_op, dataset=csv)) return samples def main(_): initialize_globals() if not FLAGS.test_files: log_error('You need to specify what files to use for evaluation via ' 'the --test_files flag.') exit(1) from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading) if FLAGS.test_output_file: # Save decoded tuples as JSON, converting NumPy floats to Python floats json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float) if __name__ == '__main__': create_flags() absl.app.run(main)