DeepSpeech/native_client/ctcdecode/path_trie.cpp
Reuben Morais e3bf5d3cc6 Only update time step of leaf prefixes
The intention of this check is to improve the accuracy of the timings by recording the time step where the character saw its highest probability rather than the first time step where it was seen. The problem happens when updating the time step of a prefix that already has children. In that case, if any of the children have a time step that is earlier than `new_timestep`, it'll break the linearity of the timings. My fix is to simply check that the prefix we're updating is a leaf.

For example, say during decoding we have the following beams (format is `(char | time)`, tree node id below, nodes with same id are the same object):

```
1. (-1 | 0 ) -> ('s' | 10) -> ('h' | 13) -> ('e' | 14)
        A                B                  C                D

2. (-1 | 0 ) -> ('s' | 10) -> ('h' | 14)
        A                B                  E
```

And the prefix list is [B, C, D, E]. Currently, if we process character 'h' in time step 15 with a probability higher than both C and E, we update both nodes to have time step 15, which breaks linearity in beam 1. With my fix, we only update node E, which is a leaf. In my tests this does fix the problem, but since we don't have any known good quality data to verify against, it's hard to know if it has other side effects.
2019-08-20 12:03:59 +02:00

177 lines
5.2 KiB
C++

#include "path_trie.h"
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "decoder_utils.h"
PathTrie::PathTrie() {
log_prob_b_prev = -NUM_FLT_INF;
log_prob_nb_prev = -NUM_FLT_INF;
log_prob_b_cur = -NUM_FLT_INF;
log_prob_nb_cur = -NUM_FLT_INF;
log_prob_c = -NUM_FLT_INF;
score = -NUM_FLT_INF;
ROOT_ = -1;
character = ROOT_;
timestep = 0;
exists_ = true;
parent = nullptr;
dictionary_ = nullptr;
dictionary_state_ = 0;
has_dictionary_ = false;
matcher_ = nullptr;
}
PathTrie::~PathTrie() {
for (auto child : children_) {
delete child.second;
}
}
PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_prob_c, bool reset) {
auto child = children_.begin();
for (child = children_.begin(); child != children_.end(); ++child) {
if (child->first == new_char) {
// If existing child matches this new_char but had a lower probability,
// and it's a leaf, update its timestep to new_timestep.
// The leak check makes sure we don't update the child to have a later
// timestep than a grandchild.
if (child->second->log_prob_c < cur_log_prob_c &&
child->second->children_.size() == 0) {
child->second->log_prob_c = cur_log_prob_c;
child->second->timestep = new_timestep;
}
break;
}
}
if (child != children_.end()) {
if (!child->second->exists_) {
child->second->exists_ = true;
child->second->log_prob_b_prev = -NUM_FLT_INF;
child->second->log_prob_nb_prev = -NUM_FLT_INF;
child->second->log_prob_b_cur = -NUM_FLT_INF;
child->second->log_prob_nb_cur = -NUM_FLT_INF;
}
return child->second;
} else {
if (has_dictionary_) {
matcher_->SetState(dictionary_state_);
bool found = matcher_->Find(new_char + 1);
if (!found) {
// Adding this character causes word outside dictionary
auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = dictionary_->Final(dictionary_state_);
bool is_final = (final_weight != FSTZERO);
if (is_final && reset) {
dictionary_state_ = dictionary_->Start();
}
return nullptr;
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->timestep = new_timestep;
new_path->parent = this;
new_path->dictionary_ = dictionary_;
new_path->has_dictionary_ = true;
new_path->matcher_ = matcher_;
new_path->log_prob_c = cur_log_prob_c;
// set spell checker state
// check to see if next state is final
auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = dictionary_->Final(matcher_->Value().nextstate);
bool is_final = (final_weight != FSTZERO);
if (is_final && reset) {
// restart spell checker at the start state
new_path->dictionary_state_ = dictionary_->Start();
} else {
// go to next state
new_path->dictionary_state_ = matcher_->Value().nextstate;
}
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->timestep = new_timestep;
new_path->parent = this;
new_path->log_prob_c = cur_log_prob_c;
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
}
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timesteps) {
return get_path_vec(output, timesteps, ROOT_);
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
std::vector<int>& timesteps,
int stop,
size_t max_steps) {
if (character == stop || character == ROOT_ || output.size() == max_steps) {
std::reverse(output.begin(), output.end());
std::reverse(timesteps.begin(), timesteps.end());
return this;
} else {
output.push_back(character);
timesteps.push_back(timestep);
return parent->get_path_vec(output, timesteps, stop, max_steps);
}
}
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
if (exists_) {
log_prob_b_prev = log_prob_b_cur;
log_prob_nb_prev = log_prob_nb_cur;
log_prob_b_cur = -NUM_FLT_INF;
log_prob_nb_cur = -NUM_FLT_INF;
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
output.push_back(this);
}
for (auto child : children_) {
child.second->iterate_to_vec(output);
}
}
void PathTrie::remove() {
exists_ = false;
if (children_.size() == 0) {
for (auto child = parent->children_.begin(); child != parent->children_.end(); ++child) {
if (child->first == character) {
parent->children_.erase(child);
break;
}
}
if (parent->children_.size() == 0 && !parent->exists_) {
parent->remove();
}
delete this;
}
}
void PathTrie::set_dictionary(PathTrie::FstType* dictionary) {
dictionary_ = dictionary;
dictionary_state_ = dictionary->Start();
has_dictionary_ = true;
}
void PathTrie::set_matcher(std::shared_ptr<fst::SortedMatcher<FstType>> matcher) {
matcher_ = matcher;
}