Skip to content

Commit 4a1069a

Browse files
taku910hiroyuki-komatsu
authored andcommitted
Removed legacy ShouldRevertTypingCorrection.
Migrating to new TypingCorrectionReranker component. PiperOrigin-RevId: 690046621
1 parent 550b151 commit 4a1069a

5 files changed

+24
-268
lines changed

src/engine/supplemental_model_interface.h

-10
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,6 @@ class SupplementalModelInterface {
8686
const ConversionRequest &request, const Segments &segments,
8787
std::vector<absl::Nonnull<const prediction::Result *>> *results) const {}
8888

89-
// Returns true if the final typing correct result is not confident.
90-
// TODO(taku): Remove this function after finishing the migration of
91-
// the more general SuppressTypingCorrection method.
92-
virtual bool ShouldRevertTypingCorrection(
93-
const ConversionRequest &request, const Segments &segments,
94-
absl::Span<const prediction::Result> literal_results,
95-
absl::Span<const prediction::Result> typing_corrected_results) const {
96-
return false;
97-
}
98-
9989
// Performs general post correction on `segments`.
10090
virtual void PostCorrect(const ConversionRequest &request,
10191
absl::Nonnull<Segments *> segments) const {}

src/engine/supplemental_model_mock.h

-5
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,6 @@ class MockSupplementalModel : public SupplementalModelInterface {
6868
(const ConversionRequest &request, const Segments &segments,
6969
std::vector<absl::Nonnull<const prediction::Result *>> *results),
7070
(const, override));
71-
MOCK_METHOD(bool, ShouldRevertTypingCorrection,
72-
(const ConversionRequest &request, const Segments &segments,
73-
absl::Span<const prediction::Result> literal_results,
74-
absl::Span<const prediction::Result> typing_corrected_results),
75-
(const, override));
7671
MOCK_METHOD(void, PostCorrect,
7772
(const ConversionRequest &, absl::Nonnull<Segments *> segments),
7873
(const, override));

src/prediction/dictionary_predictor.cc

+11-112
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,11 @@ bool DictionaryPredictor::PredictForRequest(const ConversionRequest &request,
308308
RewriteResultsForPrediction(request, *segments, &results);
309309

310310
// Explicitly populate the typing corrected results.
311-
const TypingCorrectionMixingParams typing_correction_mixing_params =
312-
MaybePopulateTypingCorrectedResults(request, *segments, &results);
311+
MaybePopulateTypingCorrectedResults(request, *segments, &results);
313312

314313
MaybeRescoreResults(request, *segments, absl::MakeSpan(results));
315314

316-
return AddPredictionToCandidates(request, segments,
317-
typing_correction_mixing_params,
318-
absl::MakeSpan(results));
315+
return AddPredictionToCandidates(request, segments, absl::MakeSpan(results));
319316
}
320317

321318
void DictionaryPredictor::RewriteResultsForPrediction(
@@ -343,42 +340,30 @@ void DictionaryPredictor::RewriteResultsForPrediction(
343340
}
344341
}
345342

346-
TypingCorrectionMixingParams
347-
DictionaryPredictor::MaybePopulateTypingCorrectedResults(
343+
void DictionaryPredictor::MaybePopulateTypingCorrectedResults(
348344
const ConversionRequest &request, const Segments &segments,
349345
std::vector<Result> *results) const {
350-
if (!IsTypingCorrectionEnabled(request)) {
351-
return {};
352-
}
353-
354-
if (results->empty()) {
355-
return {};
346+
if (!IsTypingCorrectionEnabled(request) || results->empty()) {
347+
return;
356348
}
357349

358350
const size_t key_len = Util::CharsLen(segments.conversion_segment(0).key());
359351
constexpr int kMinTypingCorrectionKeyLen = 3;
360352
if (key_len < kMinTypingCorrectionKeyLen) {
361-
return {};
353+
return;
362354
}
363355

364356
std::vector<Result> typing_corrected_results =
365357
aggregator_->AggregateTypingCorrectedResults(request, segments);
366358
RewriteResultsForPrediction(request, segments, &typing_corrected_results);
367359

368-
const TypingCorrectionMixingParams typing_correction_mixing_params =
369-
GetTypingCorrectionMixingParams(request, segments, *results,
370-
typing_corrected_results);
371-
372360
for (auto &result : typing_corrected_results) {
373361
results->emplace_back(std::move(result));
374362
}
375-
376-
return typing_correction_mixing_params;
377363
}
378364

379365
bool DictionaryPredictor::AddPredictionToCandidates(
380366
const ConversionRequest &request, Segments *segments,
381-
const TypingCorrectionMixingParams &typing_correction_mixing_params,
382367
absl::Span<Result> results) const {
383368
DCHECK(segments);
384369

@@ -462,14 +447,8 @@ bool DictionaryPredictor::AddPredictionToCandidates(
462447
final_results_ptrs.emplace_back(&result);
463448
}
464449

465-
const auto &params = request.request().decoder_experiment_params();
466-
if (params.typing_correction_result_reranker_mode() > 0) {
467-
MaybeRerankAggressiveTypingCorrection(request, *segments,
468-
&final_results_ptrs);
469-
} else {
470-
MaybeSuppressAggressiveTypingCorrection(
471-
request, typing_correction_mixing_params, &final_results_ptrs);
472-
}
450+
MaybeRerankAggressiveTypingCorrection(request, *segments,
451+
&final_results_ptrs);
473452

474453
// Fill segments from final_results_ptrs.
475454
for (const Result *result : final_results_ptrs) {
@@ -491,72 +470,15 @@ bool DictionaryPredictor::AddPredictionToCandidates(
491470
void DictionaryPredictor::MaybeRerankAggressiveTypingCorrection(
492471
const ConversionRequest &request, const Segments &segments,
493472
std::vector<absl::Nonnull<const Result *>> *results) const {
494-
const auto &params = request.request().decoder_experiment_params();
495-
if (params.typing_correction_result_reranker_mode() == 0) return;
496-
473+
if (!IsTypingCorrectionEnabled(request) || results->empty()) {
474+
return;
475+
}
497476
const engine::SupplementalModelInterface *supplemental_model =
498477
modules_.GetSupplementalModel();
499478
if (supplemental_model == nullptr) return;
500-
501479
supplemental_model->RerankTypingCorrection(request, segments, results);
502480
}
503481

504-
// static
505-
void DictionaryPredictor::MaybeSuppressAggressiveTypingCorrection(
506-
const ConversionRequest &request,
507-
const TypingCorrectionMixingParams &typing_correction_mixing_params,
508-
std::vector<absl::Nonnull<const Result *>> *results) {
509-
if (results->empty()) return;
510-
511-
// Top is already literal.
512-
const auto &top_result = results->front();
513-
514-
auto is_typing_correction = [&](const Result &result) {
515-
return (
516-
result.types & PredictionType::TYPING_CORRECTION ||
517-
(result.candidate_attributes & Segment::Candidate::TYPING_CORRECTION));
518-
};
519-
520-
if (!is_typing_correction(*top_result)) {
521-
return;
522-
}
523-
524-
const bool force_literal_on_top =
525-
typing_correction_mixing_params.literal_on_top;
526-
const bool literal_at_least_second =
527-
typing_correction_mixing_params.literal_at_least_second;
528-
529-
if (!force_literal_on_top && !literal_at_least_second) {
530-
return;
531-
}
532-
533-
auto promote_result = [&results](int old_idx, int new_idx) {
534-
const Result *result = (*results)[old_idx];
535-
for (int i = old_idx; i >= new_idx + 1; --i)
536-
(*results)[i] = (*results)[i - 1];
537-
(*results)[new_idx] = result;
538-
};
539-
540-
const int max_size = std::min<int>(10, results->size());
541-
for (int i = 1; i < max_size; ++i) {
542-
const Result *result = (*results)[i];
543-
// Finds the first non-typing-corrected candidate.
544-
if (is_typing_correction(*result)) {
545-
continue;
546-
}
547-
// Replace the literal with top when the cost is close enough or
548-
// force_literal_on_top is true.
549-
if (force_literal_on_top) {
550-
promote_result(i, 0);
551-
} else if (literal_at_least_second && i >= 2) {
552-
// Moves the literal to the second position even when
553-
// literal-on-top condition doesn't match.
554-
promote_result(i, 1);
555-
}
556-
break;
557-
}
558-
}
559-
560482
// static
561483
void DictionaryPredictor::MaybeApplyPostCorrection(
562484
const ConversionRequest &request, const engine::Modules &modules,
@@ -1423,29 +1345,6 @@ std::shared_ptr<Result> DictionaryPredictor::MaybeGetPreviousTopResult(
14231345
return nullptr;
14241346
}
14251347

1426-
// Computes the typing correction mixing params.
1427-
// from the `literal_result` and `typing_corrected_results`
1428-
TypingCorrectionMixingParams
1429-
DictionaryPredictor::GetTypingCorrectionMixingParams(
1430-
const ConversionRequest &request, const Segments &segments,
1431-
absl::Span<const Result> literal_results,
1432-
absl::Span<const Result> typing_corrected_results) const {
1433-
TypingCorrectionMixingParams typing_correction_mixing_params;
1434-
1435-
const engine::SupplementalModelInterface *supplemental_model =
1436-
modules_.GetSupplementalModel();
1437-
1438-
if (supplemental_model) {
1439-
typing_correction_mixing_params.literal_on_top =
1440-
supplemental_model->ShouldRevertTypingCorrection(
1441-
request, segments, literal_results, typing_corrected_results);
1442-
}
1443-
1444-
typing_correction_mixing_params.literal_at_least_second = true;
1445-
1446-
return typing_correction_mixing_params;
1447-
}
1448-
14491348
} // namespace mozc::prediction
14501349

14511350
#undef MOZC_WORD_LOG_MESSAGE

src/prediction/dictionary_predictor.h

+6-33
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,6 @@ struct KeyValueView {
6868

6969
} // namespace dictionary_predictor_internal
7070

71-
// Parameters to mix the literal and typing corrected results.
72-
// These parameters define the position of literal and typing corrected
73-
// results, and determined dynamically using various quality signals.
74-
struct TypingCorrectionMixingParams {
75-
// Moves the literal candidate to the top position even when
76-
// the typing corrected result is placed at top.
77-
// Set this flag when the typing correction is less confident.
78-
bool literal_on_top = false;
79-
80-
// Moves the literal candidate to the at least second position.
81-
// When the literal candidate is already at the top, do nothing.
82-
bool literal_at_least_second = false;
83-
};
84-
8571
// Dictionary-based predictor
8672
class DictionaryPredictor : public PredictorInterface {
8773
public:
@@ -156,24 +142,16 @@ class DictionaryPredictor : public PredictorInterface {
156142
aggregator,
157143
const ImmutableConverterInterface *immutable_converter);
158144

159-
bool AddPredictionToCandidates(
160-
const ConversionRequest &request, Segments *segments,
161-
const TypingCorrectionMixingParams &typing_correction_mixing_params,
162-
absl::Span<Result> results) const;
145+
bool AddPredictionToCandidates(const ConversionRequest &request,
146+
Segments *segments,
147+
absl::Span<Result> results) const;
163148

164149
void FillCandidate(
165150
const ConversionRequest &request, const Result &result,
166151
dictionary_predictor_internal::KeyValueView key_value,
167152
const absl::flat_hash_map<std::string, int32_t> &merged_types,
168153
Segment::Candidate *candidate) const;
169154

170-
// Computes the typing correction mixing params.
171-
// from the `base_result` and `typing_corrected_results`.
172-
TypingCorrectionMixingParams GetTypingCorrectionMixingParams(
173-
const ConversionRequest &request, const Segments &segments,
174-
absl::Span<const Result> literal_results,
175-
absl::Span<const Result> typing_corrected_results) const;
176-
177155
// Returns the position of misspelled character position.
178156
//
179157
// Example:
@@ -287,19 +265,14 @@ class DictionaryPredictor : public PredictorInterface {
287265
absl::flat_hash_map<PrefixPenaltyKey, int> *cache) const;
288266

289267
// Populates typing corrected results to `results`.
290-
TypingCorrectionMixingParams MaybePopulateTypingCorrectedResults(
291-
const ConversionRequest &request, const Segments &segments,
292-
std::vector<Result> *results) const;
268+
void MaybePopulateTypingCorrectedResults(const ConversionRequest &request,
269+
const Segments &segments,
270+
std::vector<Result> *results) const;
293271

294272
void MaybeRerankAggressiveTypingCorrection(
295273
const ConversionRequest &request, const Segments &segments,
296274
std::vector<absl::Nonnull<const Result *>> *results) const;
297275

298-
static void MaybeSuppressAggressiveTypingCorrection(
299-
const ConversionRequest &request,
300-
const TypingCorrectionMixingParams &typing_correction_mixing_params,
301-
std::vector<absl::Nonnull<const Result *>> *results);
302-
303276
static void MaybeApplyPostCorrection(const ConversionRequest &request,
304277
const engine::Modules &modules,
305278
Segments *segments);

0 commit comments

Comments
 (0)