mirror of
https://github.com/mozilla/DeepSpeech.git
synced 2025-10-26 11:19:39 +00:00
Simplify decoder impl by making it object oriented, avoid pointers where possible
This commit is contained in:
parent
f442b69aeb
commit
4d882a8aec
@ -12,29 +12,29 @@
|
||||
#include "fst/fstlib.h"
|
||||
#include "path_trie.h"
|
||||
|
||||
DecoderState*
|
||||
decoder_init(const Alphabet &alphabet,
|
||||
int class_dim,
|
||||
Scorer* ext_scorer)
|
||||
{
|
||||
// dimension check
|
||||
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
|
||||
"The shape of probs does not match with "
|
||||
"the shape of the vocabulary");
|
||||
|
||||
int
|
||||
DecoderState::init(const Alphabet& alphabet,
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer)
|
||||
{
|
||||
// assign special ids
|
||||
DecoderState *state = new DecoderState;
|
||||
state->time_step = 0;
|
||||
state->space_id = alphabet.GetSpaceLabel();
|
||||
state->blank_id = alphabet.GetSize();
|
||||
abs_time_step_ = 0;
|
||||
space_id_ = alphabet.GetSpaceLabel();
|
||||
blank_id_ = alphabet.GetSize();
|
||||
|
||||
beam_size_ = beam_size;
|
||||
cutoff_prob_ = cutoff_prob;
|
||||
cutoff_top_n_ = cutoff_top_n;
|
||||
ext_scorer_ = ext_scorer;
|
||||
|
||||
// init prefixes' root
|
||||
PathTrie *root = new PathTrie;
|
||||
root->score = root->log_prob_b_prev = 0.0;
|
||||
|
||||
state->prefix_root = root;
|
||||
|
||||
state->prefixes.push_back(root);
|
||||
prefix_root_.reset(root);
|
||||
prefixes_.push_back(root);
|
||||
|
||||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||
auto dict_ptr = ext_scorer->dictionary->Copy(true);
|
||||
@ -43,51 +43,45 @@ decoder_init(const Alphabet &alphabet,
|
||||
root->set_matcher(matcher);
|
||||
}
|
||||
|
||||
return state;
|
||||
return 0;
|
||||
}
|
||||
|
||||
void
|
||||
decoder_next(const double *probs,
|
||||
const Alphabet &alphabet,
|
||||
DecoderState *state,
|
||||
int time_dim,
|
||||
int class_dim,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
size_t beam_size,
|
||||
Scorer *ext_scorer)
|
||||
DecoderState::next(const double *probs,
|
||||
int time_dim,
|
||||
int class_dim)
|
||||
{
|
||||
// prefix search over time
|
||||
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
|
||||
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++abs_time_step_) {
|
||||
auto *prob = &probs[rel_time_step*class_dim];
|
||||
|
||||
float min_cutoff = -NUM_FLT_INF;
|
||||
bool full_beam = false;
|
||||
if (ext_scorer != nullptr) {
|
||||
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
|
||||
if (ext_scorer_ != nullptr) {
|
||||
size_t num_prefixes = std::min(prefixes_.size(), beam_size_);
|
||||
std::sort(
|
||||
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
|
||||
prefixes_.begin(), prefixes_.begin() + num_prefixes, prefix_compare);
|
||||
|
||||
min_cutoff = state->prefixes[num_prefixes - 1]->score +
|
||||
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
|
||||
full_beam = (num_prefixes == beam_size);
|
||||
min_cutoff = prefixes_[num_prefixes - 1]->score +
|
||||
std::log(prob[blank_id_]) - std::max(0.0, ext_scorer_->beta);
|
||||
full_beam = (num_prefixes == beam_size_);
|
||||
}
|
||||
|
||||
std::vector<std::pair<size_t, float>> log_prob_idx =
|
||||
get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
|
||||
get_pruned_log_probs(prob, class_dim, cutoff_prob_, cutoff_top_n_);
|
||||
// loop over class dim
|
||||
for (size_t index = 0; index < log_prob_idx.size(); index++) {
|
||||
auto c = log_prob_idx[index].first;
|
||||
auto log_prob_c = log_prob_idx[index].second;
|
||||
|
||||
for (size_t i = 0; i < state->prefixes.size() && i < beam_size; ++i) {
|
||||
auto prefix = state->prefixes[i];
|
||||
for (size_t i = 0; i < prefixes_.size() && i < beam_size_; ++i) {
|
||||
auto prefix = prefixes_[i];
|
||||
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||
break;
|
||||
}
|
||||
|
||||
// blank
|
||||
if (c == state->blank_id) {
|
||||
if (c == blank_id_) {
|
||||
prefix->log_prob_b_cur =
|
||||
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
||||
continue;
|
||||
@ -100,7 +94,7 @@ decoder_next(const double *probs,
|
||||
}
|
||||
|
||||
// get new prefix
|
||||
auto prefix_new = prefix->get_path_trie(c, state->time_step, log_prob_c);
|
||||
auto prefix_new = prefix->get_path_trie(c, abs_time_step_, log_prob_c);
|
||||
|
||||
if (prefix_new != nullptr) {
|
||||
float log_p = -NUM_FLT_INF;
|
||||
@ -113,11 +107,11 @@ decoder_next(const double *probs,
|
||||
}
|
||||
|
||||
// language model scoring
|
||||
if (ext_scorer != nullptr &&
|
||||
(c == state->space_id || ext_scorer->is_character_based())) {
|
||||
if (ext_scorer_ != nullptr &&
|
||||
(c == space_id_ || ext_scorer_->is_character_based())) {
|
||||
PathTrie *prefix_to_score = nullptr;
|
||||
// skip scoring the space
|
||||
if (ext_scorer->is_character_based()) {
|
||||
if (ext_scorer_->is_character_based()) {
|
||||
prefix_to_score = prefix_new;
|
||||
} else {
|
||||
prefix_to_score = prefix;
|
||||
@ -125,10 +119,10 @@ decoder_next(const double *probs,
|
||||
|
||||
float score = 0.0;
|
||||
std::vector<std::string> ngram;
|
||||
ngram = ext_scorer->make_ngram(prefix_to_score);
|
||||
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
||||
ngram = ext_scorer_->make_ngram(prefix_to_score);
|
||||
score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
|
||||
log_p += score;
|
||||
log_p += ext_scorer->beta;
|
||||
log_p += ext_scorer_->beta;
|
||||
}
|
||||
|
||||
prefix_new->log_prob_nb_cur =
|
||||
@ -138,53 +132,50 @@ decoder_next(const double *probs,
|
||||
} // end of loop over alphabet
|
||||
|
||||
// update log probs
|
||||
state->prefixes.clear();
|
||||
state->prefix_root->iterate_to_vec(state->prefixes);
|
||||
prefixes_.clear();
|
||||
prefix_root_->iterate_to_vec(prefixes_);
|
||||
|
||||
// only preserve top beam_size prefixes
|
||||
if (state->prefixes.size() > beam_size) {
|
||||
std::nth_element(state->prefixes.begin(),
|
||||
state->prefixes.begin() + beam_size,
|
||||
state->prefixes.end(),
|
||||
if (prefixes_.size() > beam_size_) {
|
||||
std::nth_element(prefixes_.begin(),
|
||||
prefixes_.begin() + beam_size_,
|
||||
prefixes_.end(),
|
||||
prefix_compare);
|
||||
for (size_t i = beam_size; i < state->prefixes.size(); ++i) {
|
||||
state->prefixes[i]->remove();
|
||||
for (size_t i = beam_size_; i < prefixes_.size(); ++i) {
|
||||
prefixes_[i]->remove();
|
||||
}
|
||||
|
||||
// Remove the elements from std::vector
|
||||
state->prefixes.resize(beam_size);
|
||||
prefixes_.resize(beam_size_);
|
||||
}
|
||||
} // end of loop over time
|
||||
}
|
||||
|
||||
std::vector<Output>
|
||||
decoder_decode(DecoderState *state,
|
||||
const Alphabet &alphabet,
|
||||
size_t beam_size,
|
||||
Scorer* ext_scorer)
|
||||
DecoderState::decode() const
|
||||
{
|
||||
std::vector<PathTrie*> prefixes_copy = state->prefixes;
|
||||
std::vector<PathTrie*> prefixes_copy = prefixes_;
|
||||
std::unordered_map<const PathTrie*, float> scores;
|
||||
for (PathTrie* prefix : prefixes_copy) {
|
||||
scores[prefix] = prefix->score;
|
||||
}
|
||||
|
||||
// score the last word of each prefix that doesn't end with space
|
||||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||
for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) {
|
||||
if (ext_scorer_ != nullptr && !ext_scorer_->is_character_based()) {
|
||||
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
|
||||
auto prefix = prefixes_copy[i];
|
||||
if (!prefix->is_empty() && prefix->character != state->space_id) {
|
||||
if (!prefix->is_empty() && prefix->character != space_id_) {
|
||||
float score = 0.0;
|
||||
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
|
||||
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
||||
score += ext_scorer->beta;
|
||||
std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
|
||||
score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
|
||||
score += ext_scorer_->beta;
|
||||
scores[prefix] += score;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using namespace std::placeholders;
|
||||
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size);
|
||||
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
|
||||
std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores));
|
||||
|
||||
//TODO: expose this as an API parameter
|
||||
@ -194,16 +185,16 @@ decoder_decode(DecoderState *state,
|
||||
// return order of decoding result. To delete when decoder gets stable.
|
||||
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) {
|
||||
double approx_ctc = scores[prefixes_copy[i]];
|
||||
if (ext_scorer != nullptr) {
|
||||
if (ext_scorer_ != nullptr) {
|
||||
std::vector<int> output;
|
||||
std::vector<int> timesteps;
|
||||
prefixes_copy[i]->get_path_vec(output, timesteps);
|
||||
auto prefix_length = output.size();
|
||||
auto words = ext_scorer->split_labels(output);
|
||||
auto words = ext_scorer_->split_labels(output);
|
||||
// remove word insert
|
||||
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
|
||||
approx_ctc = approx_ctc - prefix_length * ext_scorer_->beta;
|
||||
// remove language model weight:
|
||||
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
|
||||
approx_ctc -= (ext_scorer_->get_sent_log_prob(words)) * ext_scorer_->alpha;
|
||||
}
|
||||
prefixes_copy[i]->approx_ctc = approx_ctc;
|
||||
}
|
||||
@ -221,13 +212,10 @@ std::vector<Output> ctc_beam_search_decoder(
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer)
|
||||
{
|
||||
DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
|
||||
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
|
||||
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
|
||||
|
||||
delete state;
|
||||
|
||||
return out;
|
||||
DecoderState state;
|
||||
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
|
||||
state.next(probs, time_dim, class_dim);
|
||||
return state.decode();
|
||||
}
|
||||
|
||||
std::vector<std::vector<Output>>
|
||||
|
||||
@ -7,66 +7,67 @@
|
||||
#include "scorer.h"
|
||||
#include "output.h"
|
||||
#include "alphabet.h"
|
||||
#include "decoderstate.h"
|
||||
|
||||
/* Initialize CTC beam search decoder
|
||||
class DecoderState {
|
||||
int abs_time_step_;
|
||||
int space_id_;
|
||||
int blank_id_;
|
||||
size_t beam_size_;
|
||||
double cutoff_prob_;
|
||||
size_t cutoff_top_n_;
|
||||
|
||||
* Parameters:
|
||||
* alphabet: The alphabet.
|
||||
* class_dim: Alphabet length (plus 1 for space character).
|
||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* Return:
|
||||
* A struct containing prefixes and state variables.
|
||||
*/
|
||||
DecoderState* decoder_init(const Alphabet &alphabet,
|
||||
int class_dim,
|
||||
Scorer *ext_scorer);
|
||||
Scorer* ext_scorer_; // weak
|
||||
std::vector<PathTrie*> prefixes_;
|
||||
std::unique_ptr<PathTrie> prefix_root_;
|
||||
|
||||
/* Send data to the decoder
|
||||
public:
|
||||
DecoderState() = default;
|
||||
~DecoderState() = default;
|
||||
|
||||
* Parameters:
|
||||
* probs: 2-D vector where each element is a vector of probabilities
|
||||
* over alphabet of one time step.
|
||||
* alphabet: The alphabet.
|
||||
* state: The state structure previously obtained from decoder_init().
|
||||
* time_dim: Number of timesteps.
|
||||
* class_dim: Alphabet length (plus 1 for space character).
|
||||
* cutoff_prob: Cutoff probability for pruning.
|
||||
* cutoff_top_n: Cutoff number for pruning.
|
||||
* beam_size: The width of beam search.
|
||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
*/
|
||||
void decoder_next(const double *probs,
|
||||
const Alphabet &alphabet,
|
||||
DecoderState *state,
|
||||
int time_dim,
|
||||
int class_dim,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
size_t beam_size,
|
||||
Scorer *ext_scorer);
|
||||
// Disallow copying
|
||||
DecoderState(const DecoderState&) = delete;
|
||||
DecoderState& operator=(DecoderState&) = delete;
|
||||
|
||||
/* Get transcription for the data you sent via decoder_next()
|
||||
/* Initialize CTC beam search decoder
|
||||
*
|
||||
* Parameters:
|
||||
* alphabet: The alphabet.
|
||||
* beam_size: The width of beam search.
|
||||
* cutoff_prob: Cutoff probability for pruning.
|
||||
* cutoff_top_n: Cutoff number for pruning.
|
||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* Return:
|
||||
* Zero on success, non-zero on failure.
|
||||
*/
|
||||
int init(const Alphabet& alphabet,
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer);
|
||||
|
||||
/* Send data to the decoder
|
||||
*
|
||||
* Parameters:
|
||||
* probs: 2-D vector where each element is a vector of probabilities
|
||||
* over alphabet of one time step.
|
||||
* time_dim: Number of timesteps.
|
||||
* class_dim: Number of classes (alphabet length + 1 for space character).
|
||||
*/
|
||||
void next(const double *probs,
|
||||
int time_dim,
|
||||
int class_dim);
|
||||
|
||||
/* Get transcription from current decoder state
|
||||
*
|
||||
* Return:
|
||||
* A vector where each element is a pair of score and decoding result,
|
||||
* in descending order.
|
||||
*/
|
||||
std::vector<Output> decode() const;
|
||||
};
|
||||
|
||||
* Parameters:
|
||||
* state: The state structure previously obtained from decoder_init().
|
||||
* alphabet: The alphabet.
|
||||
* beam_size: The width of beam search.
|
||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* Return:
|
||||
* A vector where each element is a pair of score and decoding result,
|
||||
* in descending order.
|
||||
*/
|
||||
std::vector<Output> decoder_decode(DecoderState *state,
|
||||
const Alphabet &alphabet,
|
||||
size_t beam_size,
|
||||
Scorer* ext_scorer);
|
||||
|
||||
/* CTC Beam Search Decoder
|
||||
* Parameters:
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
#ifndef DECODERSTATE_H_
|
||||
#define DECODERSTATE_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */
|
||||
|
||||
struct DecoderState {
|
||||
int time_step;
|
||||
int space_id;
|
||||
int blank_id;
|
||||
std::vector<PathTrie*> prefixes;
|
||||
PathTrie *prefix_root;
|
||||
|
||||
~DecoderState() {
|
||||
if (prefix_root != nullptr) {
|
||||
delete prefix_root;
|
||||
}
|
||||
prefix_root = nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
#endif // DECODERSTATE_H_
|
||||
@ -71,7 +71,7 @@ struct StreamingState {
|
||||
vector<float> previous_state_h_;
|
||||
|
||||
ModelState* model_;
|
||||
std::unique_ptr<DecoderState> decoder_state_;
|
||||
DecoderState decoder_state_;
|
||||
|
||||
StreamingState();
|
||||
~StreamingState();
|
||||
@ -133,21 +133,21 @@ StreamingState::feedAudioContent(const short* buffer,
|
||||
char*
|
||||
StreamingState::intermediateDecode()
|
||||
{
|
||||
return model_->decode(decoder_state_.get());
|
||||
return model_->decode(decoder_state_);
|
||||
}
|
||||
|
||||
char*
|
||||
StreamingState::finishStream()
|
||||
{
|
||||
finalizeStream();
|
||||
return model_->decode(decoder_state_.get());
|
||||
return model_->decode(decoder_state_);
|
||||
}
|
||||
|
||||
Metadata*
|
||||
StreamingState::finishStreamWithMetadata()
|
||||
{
|
||||
finalizeStream();
|
||||
return model_->decode_metadata(decoder_state_.get());
|
||||
return model_->decode_metadata(decoder_state_);
|
||||
}
|
||||
|
||||
void
|
||||
@ -244,23 +244,15 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
|
||||
previous_state_c_,
|
||||
previous_state_h_);
|
||||
|
||||
const int cutoff_top_n = 40;
|
||||
const double cutoff_prob = 1.0;
|
||||
const size_t num_classes = model_->alphabet_->GetSize() + 1; // +1 for blank
|
||||
const int n_frames = logits.size() / (ModelState::BATCH_SIZE * num_classes);
|
||||
|
||||
// Convert logits to double
|
||||
vector<double> inputs(logits.begin(), logits.end());
|
||||
|
||||
decoder_next(inputs.data(),
|
||||
*model_->alphabet_,
|
||||
decoder_state_.get(),
|
||||
n_frames,
|
||||
num_classes,
|
||||
cutoff_prob,
|
||||
cutoff_top_n,
|
||||
model_->beam_width_,
|
||||
model_->scorer_);
|
||||
decoder_state_.next(inputs.data(),
|
||||
n_frames,
|
||||
num_classes);
|
||||
}
|
||||
|
||||
int
|
||||
@ -340,8 +332,6 @@ DS_SetupStream(ModelState* aCtx,
|
||||
return DS_ERR_FAIL_CREATE_STREAM;
|
||||
}
|
||||
|
||||
const size_t num_classes = aCtx->alphabet_->GetSize() + 1; // +1 for blank
|
||||
|
||||
ctx->audio_buffer_.reserve(aCtx->audio_win_len_);
|
||||
ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
|
||||
ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
|
||||
@ -350,7 +340,14 @@ DS_SetupStream(ModelState* aCtx,
|
||||
ctx->previous_state_h_.resize(aCtx->state_size_, 0.f);
|
||||
ctx->model_ = aCtx;
|
||||
|
||||
ctx->decoder_state_.reset(decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_));
|
||||
const int cutoff_top_n = 40;
|
||||
const double cutoff_prob = 1.0;
|
||||
|
||||
ctx->decoder_state_.init(*aCtx->alphabet_,
|
||||
aCtx->beam_width_,
|
||||
cutoff_prob,
|
||||
cutoff_top_n,
|
||||
aCtx->scorer_);
|
||||
|
||||
*retval = ctx.release();
|
||||
return DS_ERR_OK;
|
||||
|
||||
@ -41,24 +41,17 @@ ModelState::init(const char* model_path,
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
|
||||
vector<Output>
|
||||
ModelState::decode_raw(DecoderState* state)
|
||||
{
|
||||
vector<Output> out = decoder_decode(state, *alphabet_, beam_width_, scorer_);
|
||||
return out;
|
||||
}
|
||||
|
||||
char*
|
||||
ModelState::decode(DecoderState* state)
|
||||
ModelState::decode(const DecoderState& state)
|
||||
{
|
||||
vector<Output> out = decode_raw(state);
|
||||
vector<Output> out = state.decode();
|
||||
return strdup(alphabet_->LabelsToString(out[0].tokens).c_str());
|
||||
}
|
||||
|
||||
Metadata*
|
||||
ModelState::decode_metadata(DecoderState* state)
|
||||
ModelState::decode_metadata(const DecoderState& state)
|
||||
{
|
||||
vector<Output> out = decode_raw(state);
|
||||
vector<Output> out = state.decode();
|
||||
|
||||
std::unique_ptr<Metadata> metadata(new Metadata());
|
||||
metadata->num_items = out[0].tokens.size();
|
||||
|
||||
@ -8,7 +8,8 @@
|
||||
|
||||
#include "ctcdecode/scorer.h"
|
||||
#include "ctcdecode/output.h"
|
||||
#include "ctcdecode/decoderstate.h"
|
||||
|
||||
class DecoderState;
|
||||
|
||||
struct ModelState {
|
||||
//TODO: infer batch size from model/use dynamic batch size
|
||||
@ -59,16 +60,6 @@ struct ModelState {
|
||||
std::vector<float>& state_c_output,
|
||||
std::vector<float>& state_h_output) = 0;
|
||||
|
||||
/**
|
||||
* @brief Perform decoding of the logits, using basic CTC decoder or
|
||||
* CTC decoder with KenLM enabled
|
||||
*
|
||||
* @param state Decoder state to use when decoding.
|
||||
*
|
||||
* @return Vector of Output structs directly from the CTC decoder for additional processing.
|
||||
*/
|
||||
virtual std::vector<Output> decode_raw(DecoderState* state);
|
||||
|
||||
/**
|
||||
* @brief Perform decoding of the logits, using basic CTC decoder or
|
||||
* CTC decoder with KenLM enabled
|
||||
@ -77,7 +68,7 @@ struct ModelState {
|
||||
*
|
||||
* @return String representing the decoded text.
|
||||
*/
|
||||
virtual char* decode(DecoderState* state);
|
||||
virtual char* decode(const DecoderState& state);
|
||||
|
||||
/**
|
||||
* @brief Return character-level metadata including letter timings.
|
||||
@ -87,7 +78,7 @@ struct ModelState {
|
||||
* @return Metadata struct containing MetadataItem structs for each character.
|
||||
* The user is responsible for freeing Metadata by calling DS_FreeMetadata().
|
||||
*/
|
||||
virtual Metadata* decode_metadata(DecoderState* state);
|
||||
virtual Metadata* decode_metadata(const DecoderState& state);
|
||||
};
|
||||
|
||||
#endif // MODELSTATE_H
|
||||
|
||||
Loading…
Reference in New Issue
Block a user