diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index b9797147..c33747f9 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -12,29 +12,29 @@ #include "fst/fstlib.h" #include "path_trie.h" -DecoderState* -decoder_init(const Alphabet &alphabet, - int class_dim, - Scorer* ext_scorer) -{ - // dimension check - VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1, - "The shape of probs does not match with " - "the shape of the vocabulary"); +int +DecoderState::init(const Alphabet& alphabet, + size_t beam_size, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer) +{ // assign special ids - DecoderState *state = new DecoderState; - state->time_step = 0; - state->space_id = alphabet.GetSpaceLabel(); - state->blank_id = alphabet.GetSize(); + abs_time_step_ = 0; + space_id_ = alphabet.GetSpaceLabel(); + blank_id_ = alphabet.GetSize(); + + beam_size_ = beam_size; + cutoff_prob_ = cutoff_prob; + cutoff_top_n_ = cutoff_top_n; + ext_scorer_ = ext_scorer; // init prefixes' root PathTrie *root = new PathTrie; root->score = root->log_prob_b_prev = 0.0; - - state->prefix_root = root; - - state->prefixes.push_back(root); + prefix_root_.reset(root); + prefixes_.push_back(root); if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { auto dict_ptr = ext_scorer->dictionary->Copy(true); @@ -43,51 +43,45 @@ decoder_init(const Alphabet &alphabet, root->set_matcher(matcher); } - return state; + return 0; } void -decoder_next(const double *probs, - const Alphabet &alphabet, - DecoderState *state, - int time_dim, - int class_dim, - double cutoff_prob, - size_t cutoff_top_n, - size_t beam_size, - Scorer *ext_scorer) +DecoderState::next(const double *probs, + int time_dim, + int class_dim) { // prefix search over time - for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) { + for (size_t 
rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++abs_time_step_) { auto *prob = &probs[rel_time_step*class_dim]; float min_cutoff = -NUM_FLT_INF; bool full_beam = false; - if (ext_scorer != nullptr) { - size_t num_prefixes = std::min(state->prefixes.size(), beam_size); + if (ext_scorer_ != nullptr) { + size_t num_prefixes = std::min(prefixes_.size(), beam_size_); std::sort( - state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare); + prefixes_.begin(), prefixes_.begin() + num_prefixes, prefix_compare); - min_cutoff = state->prefixes[num_prefixes - 1]->score + - std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta); - full_beam = (num_prefixes == beam_size); + min_cutoff = prefixes_[num_prefixes - 1]->score + + std::log(prob[blank_id_]) - std::max(0.0, ext_scorer_->beta); + full_beam = (num_prefixes == beam_size_); } std::vector> log_prob_idx = - get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n); + get_pruned_log_probs(prob, class_dim, cutoff_prob_, cutoff_top_n_); // loop over class dim for (size_t index = 0; index < log_prob_idx.size(); index++) { auto c = log_prob_idx[index].first; auto log_prob_c = log_prob_idx[index].second; - for (size_t i = 0; i < state->prefixes.size() && i < beam_size; ++i) { - auto prefix = state->prefixes[i]; + for (size_t i = 0; i < prefixes_.size() && i < beam_size_; ++i) { + auto prefix = prefixes_[i]; if (full_beam && log_prob_c + prefix->score < min_cutoff) { break; } // blank - if (c == state->blank_id) { + if (c == blank_id_) { prefix->log_prob_b_cur = log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); continue; @@ -100,7 +94,7 @@ decoder_next(const double *probs, } // get new prefix - auto prefix_new = prefix->get_path_trie(c, state->time_step, log_prob_c); + auto prefix_new = prefix->get_path_trie(c, abs_time_step_, log_prob_c); if (prefix_new != nullptr) { float log_p = -NUM_FLT_INF; @@ -113,11 +107,11 @@ decoder_next(const double *probs, } // 
language model scoring - if (ext_scorer != nullptr && - (c == state->space_id || ext_scorer->is_character_based())) { + if (ext_scorer_ != nullptr && + (c == space_id_ || ext_scorer_->is_character_based())) { PathTrie *prefix_to_score = nullptr; // skip scoring the space - if (ext_scorer->is_character_based()) { + if (ext_scorer_->is_character_based()) { prefix_to_score = prefix_new; } else { prefix_to_score = prefix; @@ -125,10 +119,10 @@ decoder_next(const double *probs, float score = 0.0; std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_to_score); - score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + ngram = ext_scorer_->make_ngram(prefix_to_score); + score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha; log_p += score; - log_p += ext_scorer->beta; + log_p += ext_scorer_->beta; } prefix_new->log_prob_nb_cur = @@ -138,53 +132,50 @@ decoder_next(const double *probs, } // end of loop over alphabet // update log probs - state->prefixes.clear(); - state->prefix_root->iterate_to_vec(state->prefixes); + prefixes_.clear(); + prefix_root_->iterate_to_vec(prefixes_); // only preserve top beam_size prefixes - if (state->prefixes.size() > beam_size) { - std::nth_element(state->prefixes.begin(), - state->prefixes.begin() + beam_size, - state->prefixes.end(), + if (prefixes_.size() > beam_size_) { + std::nth_element(prefixes_.begin(), + prefixes_.begin() + beam_size_, + prefixes_.end(), prefix_compare); - for (size_t i = beam_size; i < state->prefixes.size(); ++i) { - state->prefixes[i]->remove(); + for (size_t i = beam_size_; i < prefixes_.size(); ++i) { + prefixes_[i]->remove(); } // Remove the elements from std::vector - state->prefixes.resize(beam_size); + prefixes_.resize(beam_size_); } } // end of loop over time } std::vector -decoder_decode(DecoderState *state, - const Alphabet &alphabet, - size_t beam_size, - Scorer* ext_scorer) +DecoderState::decode() const { - std::vector prefixes_copy = state->prefixes; + std::vector 
prefixes_copy = prefixes_; std::unordered_map scores; for (PathTrie* prefix : prefixes_copy) { scores[prefix] = prefix->score; } // score the last word of each prefix that doesn't end with space - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) { + if (ext_scorer_ != nullptr && !ext_scorer_->is_character_based()) { + for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) { auto prefix = prefixes_copy[i]; - if (!prefix->is_empty() && prefix->character != state->space_id) { + if (!prefix->is_empty() && prefix->character != space_id_) { float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; + std::vector ngram = ext_scorer_->make_ngram(prefix); + score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha; + score += ext_scorer_->beta; scores[prefix] += score; } } } using namespace std::placeholders; - size_t num_prefixes = std::min(prefixes_copy.size(), beam_size); + size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_); std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores)); //TODO: expose this as an API parameter @@ -194,16 +185,16 @@ decoder_decode(DecoderState *state, // return order of decoding result. To delete when decoder gets stable. 
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) { double approx_ctc = scores[prefixes_copy[i]]; - if (ext_scorer != nullptr) { + if (ext_scorer_ != nullptr) { std::vector output; std::vector timesteps; prefixes_copy[i]->get_path_vec(output, timesteps); auto prefix_length = output.size(); - auto words = ext_scorer->split_labels(output); + auto words = ext_scorer_->split_labels(output); // remove word insert - approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; + approx_ctc = approx_ctc - prefix_length * ext_scorer_->beta; // remove language model weight: - approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; + approx_ctc -= (ext_scorer_->get_sent_log_prob(words)) * ext_scorer_->alpha; } prefixes_copy[i]->approx_ctc = approx_ctc; } @@ -221,13 +212,10 @@ std::vector ctc_beam_search_decoder( size_t cutoff_top_n, Scorer *ext_scorer) { - DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer); - decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer); - std::vector out = decoder_decode(state, alphabet, beam_size, ext_scorer); - - delete state; - - return out; + DecoderState state; + state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer); + state.next(probs, time_dim, class_dim); + return state.decode(); } std::vector> diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h index 81f1b613..4d6b7ea5 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.h +++ b/native_client/ctcdecode/ctc_beam_search_decoder.h @@ -7,66 +7,67 @@ #include "scorer.h" #include "output.h" #include "alphabet.h" -#include "decoderstate.h" -/* Initialize CTC beam search decoder +class DecoderState { + int abs_time_step_; + int space_id_; + int blank_id_; + size_t beam_size_; + double cutoff_prob_; + size_t cutoff_top_n_; - * Parameters: - * alphabet: The alphabet. 
- * class_dim: Alphabet length (plus 1 for space character). - * ext_scorer: External scorer to evaluate a prefix, which consists of - * n-gram language model scoring and word insertion term. - * Default null, decoding the input sample without scorer. - * Return: - * A struct containing prefixes and state variables. -*/ -DecoderState* decoder_init(const Alphabet &alphabet, - int class_dim, - Scorer *ext_scorer); + Scorer* ext_scorer_; // weak + std::vector prefixes_; + std::unique_ptr prefix_root_; -/* Send data to the decoder +public: + DecoderState() = default; + ~DecoderState() = default; - * Parameters: - * probs: 2-D vector where each element is a vector of probabilities - * over alphabet of one time step. - * alphabet: The alphabet. - * state: The state structure previously obtained from decoder_init(). - * time_dim: Number of timesteps. - * class_dim: Alphabet length (plus 1 for space character). - * cutoff_prob: Cutoff probability for pruning. - * cutoff_top_n: Cutoff number for pruning. - * beam_size: The width of beam search. - * ext_scorer: External scorer to evaluate a prefix, which consists of - * n-gram language model scoring and word insertion term. - * Default null, decoding the input sample without scorer. -*/ -void decoder_next(const double *probs, - const Alphabet &alphabet, - DecoderState *state, - int time_dim, - int class_dim, - double cutoff_prob, - size_t cutoff_top_n, - size_t beam_size, - Scorer *ext_scorer); + // Disallow copying + DecoderState(const DecoderState&) = delete; + DecoderState& operator=(DecoderState&) = delete; -/* Get transcription for the data you sent via decoder_next() + /* Initialize CTC beam search decoder + * + * Parameters: + * alphabet: The alphabet. + * beam_size: The width of beam search. + * cutoff_prob: Cutoff probability for pruning. + * cutoff_top_n: Cutoff number for pruning. + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. 
+ * Default null, decoding the input sample without scorer. + * Return: + * Zero on success, non-zero on failure. + */ + int init(const Alphabet& alphabet, + size_t beam_size, + double cutoff_prob, + size_t cutoff_top_n, + Scorer *ext_scorer); + + /* Send data to the decoder + * + * Parameters: + * probs: 2-D vector where each element is a vector of probabilities + * over alphabet of one time step. + * time_dim: Number of timesteps. + * class_dim: Number of classes (alphabet length + 1 for the CTC blank label). + */ + void next(const double *probs, + int time_dim, + int class_dim); + + /* Get transcription from current decoder state + * + * Return: + * A vector where each element is a pair of score and decoding result, + * in descending order. + */ + std::vector decode() const; +}; - * Parameters: - * state: The state structure previously obtained from decoder_init(). - * alphabet: The alphabet. - * beam_size: The width of beam search. - * ext_scorer: External scorer to evaluate a prefix, which consists of - * n-gram language model scoring and word insertion term. - * Default null, decoding the input sample without scorer. - * Return: - * A vector where each element is a pair of score and decoding result, - * in descending order. -*/ -std::vector decoder_decode(DecoderState *state, - const Alphabet &alphabet, - size_t beam_size, - Scorer* ext_scorer); /* CTC Beam Search Decoder * Parameters: diff --git a/native_client/ctcdecode/decoderstate.h b/native_client/ctcdecode/decoderstate.h deleted file mode 100644 index 1a92e80e..00000000 --- a/native_client/ctcdecode/decoderstate.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef DECODERSTATE_H_ -#define DECODERSTATE_H_ - -#include - -/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. 
*/ - -struct DecoderState { - int time_step; - int space_id; - int blank_id; - std::vector prefixes; - PathTrie *prefix_root; - - ~DecoderState() { - if (prefix_root != nullptr) { - delete prefix_root; - } - prefix_root = nullptr; - } -}; - -#endif // DECODERSTATE_H_ diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index 1578ed6d..8ddd2c72 100644 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -71,7 +71,7 @@ struct StreamingState { vector previous_state_h_; ModelState* model_; - std::unique_ptr decoder_state_; + DecoderState decoder_state_; StreamingState(); ~StreamingState(); @@ -133,21 +133,21 @@ StreamingState::feedAudioContent(const short* buffer, char* StreamingState::intermediateDecode() { - return model_->decode(decoder_state_.get()); + return model_->decode(decoder_state_); } char* StreamingState::finishStream() { finalizeStream(); - return model_->decode(decoder_state_.get()); + return model_->decode(decoder_state_); } Metadata* StreamingState::finishStreamWithMetadata() { finalizeStream(); - return model_->decode_metadata(decoder_state_.get()); + return model_->decode_metadata(decoder_state_); } void @@ -244,23 +244,15 @@ StreamingState::processBatch(const vector& buf, unsigned int n_steps) previous_state_c_, previous_state_h_); - const int cutoff_top_n = 40; - const double cutoff_prob = 1.0; const size_t num_classes = model_->alphabet_->GetSize() + 1; // +1 for blank const int n_frames = logits.size() / (ModelState::BATCH_SIZE * num_classes); // Convert logits to double vector inputs(logits.begin(), logits.end()); - decoder_next(inputs.data(), - *model_->alphabet_, - decoder_state_.get(), - n_frames, - num_classes, - cutoff_prob, - cutoff_top_n, - model_->beam_width_, - model_->scorer_); + decoder_state_.next(inputs.data(), + n_frames, + num_classes); } int @@ -340,8 +332,6 @@ DS_SetupStream(ModelState* aCtx, return DS_ERR_FAIL_CREATE_STREAM; } - const size_t num_classes = aCtx->alphabet_->GetSize() + 1; // +1 
for blank - ctx->audio_buffer_.reserve(aCtx->audio_win_len_); ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_); ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f); @@ -350,7 +340,14 @@ DS_SetupStream(ModelState* aCtx, ctx->previous_state_h_.resize(aCtx->state_size_, 0.f); ctx->model_ = aCtx; - ctx->decoder_state_.reset(decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_)); + const int cutoff_top_n = 40; + const double cutoff_prob = 1.0; + + ctx->decoder_state_.init(*aCtx->alphabet_, + aCtx->beam_width_, + cutoff_prob, + cutoff_top_n, + aCtx->scorer_); *retval = ctx.release(); return DS_ERR_OK; diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc index 7bb7f073..48c34eee 100644 --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -41,24 +41,17 @@ ModelState::init(const char* model_path, return DS_ERR_OK; } -vector -ModelState::decode_raw(DecoderState* state) -{ - vector out = decoder_decode(state, *alphabet_, beam_width_, scorer_); - return out; -} - char* -ModelState::decode(DecoderState* state) +ModelState::decode(const DecoderState& state) { - vector out = decode_raw(state); + vector out = state.decode(); return strdup(alphabet_->LabelsToString(out[0].tokens).c_str()); } Metadata* -ModelState::decode_metadata(DecoderState* state) +ModelState::decode_metadata(const DecoderState& state) { - vector out = decode_raw(state); + vector out = state.decode(); std::unique_ptr metadata(new Metadata()); metadata->num_items = out[0].tokens.size(); diff --git a/native_client/modelstate.h b/native_client/modelstate.h index 71799421..cb7c7d34 100644 --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -8,7 +8,8 @@ #include "ctcdecode/scorer.h" #include "ctcdecode/output.h" -#include "ctcdecode/decoderstate.h" + +class DecoderState; struct ModelState { //TODO: infer batch size from model/use dynamic batch size @@ -59,16 +60,6 @@ struct ModelState { std::vector& state_c_output, std::vector& 
state_h_output) = 0; - /** - * @brief Perform decoding of the logits, using basic CTC decoder or - * CTC decoder with KenLM enabled - * - * @param state Decoder state to use when decoding. - * - * @return Vector of Output structs directly from the CTC decoder for additional processing. - */ - virtual std::vector decode_raw(DecoderState* state); - /** * @brief Perform decoding of the logits, using basic CTC decoder or * CTC decoder with KenLM enabled @@ -77,7 +68,7 @@ struct ModelState { * * @return String representing the decoded text. */ - virtual char* decode(DecoderState* state); + virtual char* decode(const DecoderState& state); /** * @brief Return character-level metadata including letter timings. @@ -87,7 +78,7 @@ struct ModelState { * @return Metadata struct containing MetadataItem structs for each character. * The user is responsible for freeing Metadata by calling DS_FreeMetadata(). */ - virtual Metadata* decode_metadata(DecoderState* state); + virtual Metadata* decode_metadata(const DecoderState& state); }; #endif // MODELSTATE_H