#include "decoder_utils.h"

// NOTE(review): the names of the three standard headers originally included
// here, and every template argument list in this file, were lost when the
// file was flattened (angle-bracket content was stripped). They have been
// reconstructed from how the code uses them; confirm against decoder_utils.h
// before merging.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>

// Converts one time step of class probabilities into a pruned list of
// (class index, log probability) pairs.
//
// prob_step     array of `class_dim` probabilities, indexed by class id.
// class_dim     number of entries in `prob_step`.
// cutoff_prob   when < 1.0, keep the most probable classes until their
//               cumulative probability reaches this value.
// cutoff_top_n  upper bound on the number of classes kept.
//
// When neither cutoff applies (cutoff_prob >= 1.0 and cutoff_top_n >=
// class_dim) every class is returned in its original order. Otherwise the
// result is ordered by pair_comp_second_rev (presumably descending
// probability -- the comparator is defined elsewhere) and truncated. Whenever
// pruning is active and class_dim > 0, at least one entry is kept.
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
    const double *prob_step,
    size_t class_dim,
    double cutoff_prob,
    size_t cutoff_top_n) {
  std::vector<std::pair<size_t, double>> prob_idx;
  prob_idx.reserve(class_dim);
  for (size_t i = 0; i < class_dim; ++i) {
    prob_idx.emplace_back(i, prob_step[i]);
  }

  // Pruning of vocabulary.
  if (cutoff_prob < 1.0 || cutoff_top_n < class_dim) {
    // NOTE(review): assumes pair_comp_second_rev is a template over the two
    // pair element types, as the stripped explicit arguments suggest.
    std::sort(prob_idx.begin(),
              prob_idx.end(),
              pair_comp_second_rev<size_t, double>);

    // The cumulative-probability test only applies when cutoff_prob < 1.0,
    // but the top-n bound must be honoured in either case. Previously the
    // top-n bound was ignored whenever cutoff_prob >= 1.0, so nothing was
    // pruned even though the list had been sorted for pruning.
    const bool use_prob_cutoff = cutoff_prob < 1.0;
    double cum_prob = 0.0;
    size_t kept = 0;
    while (kept < prob_idx.size()) {
      cum_prob += prob_idx[kept].second;
      ++kept;
      if ((use_prob_cutoff && cum_prob >= cutoff_prob) ||
          kept >= cutoff_top_n) {
        break;
      }
    }
    prob_idx.erase(prob_idx.begin() + kept, prob_idx.end());
  }

  std::vector<std::pair<size_t, float>> log_prob_idx;
  log_prob_idx.reserve(prob_idx.size());
  for (const auto &entry : prob_idx) {
    // NUM_FLT_MIN is added before taking the log so that a probability of
    // exactly zero does not produce -inf.
    log_prob_idx.emplace_back(
        entry.first,
        static_cast<float>(std::log(entry.second + NUM_FLT_MIN)));
  }
  return log_prob_idx;
}

// Returns the number of bytes in `str` that are not UTF-8 continuation bytes
// (bytes of the form 10xxxxxx). For well-formed UTF-8 this is the number of
// code points.
size_t get_utf8_str_len(const std::string &str) {
  size_t str_len = 0;
  for (char c : str) {
    str_len += ((c & 0xc0) != 0x80);
  }
  return str_len;
}

// Splits `str` into pieces, starting a new piece at every byte for which
// byte_is_codepoint_boundary() is true. The final piece is always appended,
// so an empty input yields a vector holding one empty string.
std::vector<std::string> split_into_codepoints(const std::string &str) {
  std::vector<std::string> result;
  std::string out_str;
  for (char c : str) {
    if (byte_is_codepoint_boundary(c)) {
      // Flush the previous piece before starting the next one.
      if (!out_str.empty()) {
        result.push_back(out_str);
        out_str.clear();
      }
    }
    out_str.append(1, c);
  }
  result.push_back(out_str);
  return result;
}

// Splits `str` into one single-character string per byte.
std::vector<std::string> split_into_bytes(const std::string &str) {
  std::vector<std::string> result;
  result.reserve(str.size());
  for (char c : str) {
    result.emplace_back(1, c);
  }
  return result;
}

// Splits `s` on every occurrence of `delim`, dropping empty tokens (leading,
// trailing and consecutive delimiters produce no entries).
//
// An empty `delim` cannot split anything; the whole string is returned as a
// single token (or nothing if `s` is empty). Without this guard find() would
// keep matching the empty delimiter at `start` and the loop would never end.
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim) {
  std::vector<std::string> result;
  if (delim.empty()) {
    if (!s.empty()) {
      result.push_back(s);
    }
    return result;
  }

  std::size_t start = 0, delim_len = delim.size();
  while (true) {
    std::size_t end = s.find(delim, start);
    if (end == std::string::npos) {
      if (start < s.size()) {
        result.push_back(s.substr(start));
      }
      break;
    }
    if (end > start) {
      result.push_back(s.substr(start, end - start));
    }
    start = end + delim_len;
  }
  return result;
}

// Ordering predicate for prefixes: higher `score` first; ties are broken by
// ascending `character`. Returns false when both fields compare equal.
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
  if (x->score != y->score) {
    return x->score > y->score;
  }
  return x->character < y->character;
}

// Same ordering as prefix_compare, but scores are looked up in `scores`
// instead of being read from the nodes. Throws std::out_of_range (from
// unordered_map::at) if either node has no entry in `scores`.
//
// NOTE(review): the mapped type `float` is reconstructed; confirm against the
// declaration in decoder_utils.h.
bool prefix_compare_external(
    const PathTrie *x,
    const PathTrie *y,
    const std::unordered_map<const PathTrie *, float> &scores) {
  // Look each score up once instead of once per comparison.
  const auto score_x = scores.at(x);
  const auto score_y = scores.at(y);
  if (score_x != score_y) {
    return score_x > score_y;
  }
  return x->character < y->character;
}

// Appends `word` to `dictionary` as a fresh chain of states hanging off the
// start state, one arc per element (input label == output label, weight 0),
// and marks the last state final. The start state is created on first use.
//
// An empty `word` adds nothing: there would be no destination state to mark
// final (the previous code read an uninitialized state id in that case).
//
// NOTE(review): the element type `unsigned int` is reconstructed; confirm
// against the declaration in decoder_utils.h.
void add_word_to_fst(const std::vector<unsigned int> &word,
                     fst::StdVectorFst *dictionary) {
  if (dictionary->NumStates() == 0) {
    fst::StdVectorFst::StateId start = dictionary->AddState();
    assert(start == 0);
    dictionary->SetStart(start);
  }
  if (word.empty()) {
    return;
  }

  fst::StdVectorFst::StateId src = dictionary->Start();
  fst::StdVectorFst::StateId dst = src;
  for (auto c : word) {
    dst = dictionary->AddState();
    dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
    src = dst;
  }
  dictionary->SetFinal(dst, fst::StdArc::Weight::One());
}

// Maps `word` to label ids through `char_map` and adds it to `dictionary`.
//
// word        text to add.
// char_map    lookup from a unit (one byte when `utf8` is true, otherwise one
//             piece produced by split_into_codepoints) to its label id.
// utf8        selects byte-level splitting; when false, SPACE_ID is appended
//             after the word's labels.
// SPACE_ID    label id appended in non-utf8 mode.
// dictionary  FST that receives the word.
//
// Returns true if the word was added. Returns false, leaving `dictionary`
// untouched, if any unit is missing from `char_map` or if the word maps to no
// labels at all (empty `word` in utf8 mode).
//
// NOTE(review): the mapped type `int` of `char_map` is reconstructed; confirm
// against the declaration in decoder_utils.h.
bool add_word_to_dictionary(
    const std::string &word,
    const std::unordered_map<std::string, int> &char_map,
    bool utf8,
    int SPACE_ID,
    fst::StdVectorFst *dictionary) {
  auto characters = utf8 ? split_into_bytes(word) : split_into_codepoints(word);

  std::vector<unsigned int> int_word;
  int_word.reserve(characters.size() + 1);

  for (auto &c : characters) {
    auto int_c = char_map.find(c);
    if (int_c != char_map.end()) {
      int_word.push_back(int_c->second);
    } else {
      return false;  // return without adding
    }
  }

  if (!utf8) {
    int_word.push_back(SPACE_ID);
  }

  if (int_word.empty()) {
    return false;  // nothing to add
  }

  add_word_to_fst(int_word, dictionary);
  return true;  // return with successful adding
}