mirror of
https://github.com/mozilla/DeepSpeech.git
synced 2025-10-26 11:19:39 +00:00
275 lines
8.7 KiB
C++
275 lines
8.7 KiB
C++
#include "ctc_beam_search_decoder.h"
|
|
|
|
#include <algorithm>
|
|
#include <cmath>
|
|
#include <iostream>
|
|
#include <limits>
|
|
#include <map>
|
|
#include <utility>
|
|
|
|
#include "decoder_utils.h"
|
|
#include "ThreadPool.h"
|
|
#include "fst/fstlib.h"
|
|
#include "path_trie.h"
|
|
|
|
|
|
int
|
|
DecoderState::init(const Alphabet& alphabet,
|
|
size_t beam_size,
|
|
double cutoff_prob,
|
|
size_t cutoff_top_n,
|
|
std::shared_ptr<Scorer> ext_scorer)
|
|
{
|
|
// assign special ids
|
|
abs_time_step_ = 0;
|
|
space_id_ = alphabet.GetSpaceLabel();
|
|
blank_id_ = alphabet.GetSize();
|
|
|
|
beam_size_ = beam_size;
|
|
cutoff_prob_ = cutoff_prob;
|
|
cutoff_top_n_ = cutoff_top_n;
|
|
ext_scorer_ = ext_scorer;
|
|
start_expanding_ = false;
|
|
|
|
// init prefixes' root
|
|
PathTrie *root = new PathTrie;
|
|
root->score = root->log_prob_b_prev = 0.0;
|
|
prefix_root_.reset(root);
|
|
prefixes_.push_back(root);
|
|
|
|
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
|
|
// no need for std::make_shared<>() since Copy() does 'new' behind the doors
|
|
auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true));
|
|
root->set_dictionary(dict_ptr);
|
|
auto matcher = std::make_shared<fst::SortedMatcher<PathTrie::FstType>>(*dict_ptr, fst::MATCH_INPUT);
|
|
root->set_matcher(matcher);
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
void
|
|
DecoderState::next(const double *probs,
|
|
int time_dim,
|
|
int class_dim)
|
|
{
|
|
// prefix search over time
|
|
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++abs_time_step_) {
|
|
auto *prob = &probs[rel_time_step*class_dim];
|
|
|
|
// At the start of the decoding process, we delay beam expansion so that
|
|
// timings on the first letters is not incorrect. As soon as we see a
|
|
// timestep with blank probability lower than 0.999, we start expanding
|
|
// beams.
|
|
if (prob[blank_id_] < 0.999) {
|
|
start_expanding_ = true;
|
|
}
|
|
|
|
// If not expanding yet, just continue to next timestep.
|
|
if (!start_expanding_) {
|
|
continue;
|
|
}
|
|
|
|
float min_cutoff = -NUM_FLT_INF;
|
|
bool full_beam = false;
|
|
if (ext_scorer_) {
|
|
size_t num_prefixes = std::min(prefixes_.size(), beam_size_);
|
|
std::partial_sort(prefixes_.begin(),
|
|
prefixes_.begin() + num_prefixes,
|
|
prefixes_.end(),
|
|
prefix_compare);
|
|
|
|
min_cutoff = prefixes_[num_prefixes - 1]->score +
|
|
std::log(prob[blank_id_]) - std::max(0.0, ext_scorer_->beta);
|
|
full_beam = (num_prefixes == beam_size_);
|
|
}
|
|
|
|
std::vector<std::pair<size_t, float>> log_prob_idx =
|
|
get_pruned_log_probs(prob, class_dim, cutoff_prob_, cutoff_top_n_);
|
|
// loop over class dim
|
|
for (size_t index = 0; index < log_prob_idx.size(); index++) {
|
|
auto c = log_prob_idx[index].first;
|
|
auto log_prob_c = log_prob_idx[index].second;
|
|
|
|
for (size_t i = 0; i < prefixes_.size() && i < beam_size_; ++i) {
|
|
auto prefix = prefixes_[i];
|
|
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
|
break;
|
|
}
|
|
|
|
// blank
|
|
if (c == blank_id_) {
|
|
prefix->log_prob_b_cur =
|
|
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
|
continue;
|
|
}
|
|
|
|
// repeated character
|
|
if (c == prefix->character) {
|
|
prefix->log_prob_nb_cur = log_sum_exp(
|
|
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
|
|
}
|
|
|
|
// get new prefix
|
|
auto prefix_new = prefix->get_path_trie(c, abs_time_step_, log_prob_c);
|
|
|
|
if (prefix_new != nullptr) {
|
|
float log_p = -NUM_FLT_INF;
|
|
|
|
if (c == prefix->character &&
|
|
prefix->log_prob_b_prev > -NUM_FLT_INF) {
|
|
log_p = log_prob_c + prefix->log_prob_b_prev;
|
|
} else if (c != prefix->character) {
|
|
log_p = log_prob_c + prefix->score;
|
|
}
|
|
|
|
if (ext_scorer_) {
|
|
// skip scoring the space in word based LMs
|
|
PathTrie* prefix_to_score;
|
|
if (ext_scorer_->is_utf8_mode()) {
|
|
prefix_to_score = prefix_new;
|
|
} else {
|
|
prefix_to_score = prefix;
|
|
}
|
|
|
|
// language model scoring
|
|
if (ext_scorer_->is_scoring_boundary(prefix_to_score, c)) {
|
|
float score = 0.0;
|
|
std::vector<std::string> ngram;
|
|
ngram = ext_scorer_->make_ngram(prefix_to_score);
|
|
bool bos = ngram.size() < ext_scorer_->get_max_order();
|
|
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
|
|
log_p += score;
|
|
log_p += ext_scorer_->beta;
|
|
}
|
|
}
|
|
|
|
prefix_new->log_prob_nb_cur =
|
|
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
|
}
|
|
} // end of loop over prefix
|
|
} // end of loop over alphabet
|
|
|
|
// update log probs
|
|
prefixes_.clear();
|
|
prefix_root_->iterate_to_vec(prefixes_);
|
|
|
|
// only preserve top beam_size prefixes
|
|
if (prefixes_.size() > beam_size_) {
|
|
std::nth_element(prefixes_.begin(),
|
|
prefixes_.begin() + beam_size_,
|
|
prefixes_.end(),
|
|
prefix_compare);
|
|
for (size_t i = beam_size_; i < prefixes_.size(); ++i) {
|
|
prefixes_[i]->remove();
|
|
}
|
|
|
|
// Remove the elements from std::vector
|
|
prefixes_.resize(beam_size_);
|
|
}
|
|
} // end of loop over time
|
|
}
|
|
|
|
std::vector<Output>
|
|
DecoderState::decode(size_t num_results) const
|
|
{
|
|
std::vector<PathTrie*> prefixes_copy = prefixes_;
|
|
std::unordered_map<const PathTrie*, float> scores;
|
|
for (PathTrie* prefix : prefixes_copy) {
|
|
scores[prefix] = prefix->score;
|
|
}
|
|
|
|
// score the last word of each prefix that doesn't end with space
|
|
if (ext_scorer_) {
|
|
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
|
|
PathTrie* prefix = prefixes_copy[i];
|
|
PathTrie* prefix_boundary = ext_scorer_->is_utf8_mode() ? prefix : prefix->parent;
|
|
if (prefix_boundary && !ext_scorer_->is_scoring_boundary(prefix_boundary, prefix->character)) {
|
|
float score = 0.0;
|
|
std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
|
|
bool bos = ngram.size() < ext_scorer_->get_max_order();
|
|
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
|
|
score += ext_scorer_->beta;
|
|
scores[prefix] += score;
|
|
}
|
|
}
|
|
}
|
|
|
|
using namespace std::placeholders;
|
|
size_t num_returned = std::min(prefixes_copy.size(), num_results);
|
|
std::partial_sort(prefixes_copy.begin(),
|
|
prefixes_copy.begin() + num_returned,
|
|
prefixes_copy.end(),
|
|
std::bind(prefix_compare_external, _1, _2, scores));
|
|
|
|
std::vector<Output> outputs;
|
|
outputs.reserve(num_returned);
|
|
|
|
for (size_t i = 0; i < num_returned; ++i) {
|
|
Output output;
|
|
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps);
|
|
output.confidence = scores[prefixes_copy[i]];
|
|
outputs.push_back(output);
|
|
}
|
|
|
|
return outputs;
|
|
}
|
|
|
|
std::vector<Output> ctc_beam_search_decoder(
|
|
const double *probs,
|
|
int time_dim,
|
|
int class_dim,
|
|
const Alphabet &alphabet,
|
|
size_t beam_size,
|
|
double cutoff_prob,
|
|
size_t cutoff_top_n,
|
|
std::shared_ptr<Scorer> ext_scorer)
|
|
{
|
|
DecoderState state;
|
|
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
|
|
state.next(probs, time_dim, class_dim);
|
|
return state.decode();
|
|
}
|
|
|
|
std::vector<std::vector<Output>>
|
|
ctc_beam_search_decoder_batch(
|
|
const double *probs,
|
|
int batch_size,
|
|
int time_dim,
|
|
int class_dim,
|
|
const int* seq_lengths,
|
|
int seq_lengths_size,
|
|
const Alphabet &alphabet,
|
|
size_t beam_size,
|
|
size_t num_processes,
|
|
double cutoff_prob,
|
|
size_t cutoff_top_n,
|
|
std::shared_ptr<Scorer> ext_scorer)
|
|
{
|
|
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
|
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
|
|
// thread pool
|
|
ThreadPool pool(num_processes);
|
|
|
|
// enqueue the tasks of decoding
|
|
std::vector<std::future<std::vector<Output>>> res;
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
|
|
&probs[i*time_dim*class_dim],
|
|
seq_lengths[i],
|
|
class_dim,
|
|
alphabet,
|
|
beam_size,
|
|
cutoff_prob,
|
|
cutoff_top_n,
|
|
ext_scorer));
|
|
}
|
|
|
|
// get decoding results
|
|
std::vector<std::vector<Output>> batch_results;
|
|
for (size_t i = 0; i < batch_size; ++i) {
|
|
batch_results.emplace_back(res[i].get());
|
|
}
|
|
return batch_results;
|
|
}
|