Source code for asr_eval.correction.comparator_wordfreq
from typing import cast, override
import string
from asr_eval.align.solvers.dynprog import solve_optimal_alignment
from asr_eval.align.parsing import DEFAULT_PARSER
from asr_eval.align.transcription import Token
from asr_eval.linguistics.linguistics import word_freq
from asr_eval.models.base.interfaces import Transcriber
from asr_eval.utils.types import FLOATS
[docs]
class RuWordFreqComparator(Transcriber):
    """Composite transcriber that merges two models by word frequency.

    Both models transcribe the same waveform. The two transcriptions
    are aligned, and wherever the alignment reports a word replacement,
    the base model's word is swapped for the additional model's word if
    the latter has a higher ``word_freq`` value. The base model is
    expected to be the generally better one.

    Works for Russian language currently: frequencies are looked up
    with the ``'ru'`` language code, and a replacement is never made
    when either word contains an ASCII letter or a digit.
    """

    # Characters that disqualify a word from replacement (ASCII letters,
    # compared after lowercasing, and digits). Built once for the class
    # instead of on every transcribe() call.
    _SKIP_CHARS = frozenset(string.ascii_lowercase + string.digits)

    def __init__(
        self,
        base_model: Transcriber,  # TODO add special code for TimedTranscriber
        additional_model: Transcriber,
    ):
        """Store the two wrapped transcribers.

        Args:
            base_model: Model whose transcription is used as the base
                text.
            additional_model: Model whose words may replace words of
                the base text.
        """
        self.base_model = base_model
        self.additional_model = additional_model

    @classmethod
    def _should_skip(cls, word: str) -> bool:
        """Return True if *word* contains an ASCII letter or a digit."""
        return not cls._SKIP_CHARS.isdisjoint(word.lower())

    @override
    def transcribe(self, waveform: FLOATS) -> str:
        """Transcribe *waveform* with both models and merge the results.

        Args:
            waveform: Audio passed unchanged to both wrapped models.

        Returns:
            The text of the parsed base transcription, with the
            selected words replaced by the additional model's words.
        """
        base_text = DEFAULT_PARSER.parse_single_variant_transcription(
            self.base_model.transcribe(waveform)
        )
        additional_text = DEFAULT_PARSER.parse_single_variant_transcription(
            self.additional_model.transcribe(waveform)
        )
        # The base transcription is passed first, so its tokens are
        # read from ``match.true`` and the additional model's tokens
        # from ``match.pred``.
        matches, _ = solve_optimal_alignment(base_text, additional_text)

        replacements: list[tuple[Token, str]] = []
        for match in matches.matches:
            if match.status != 'replacement':
                continue
            base_token = cast(Token, match.true)
            base_word = cast(str, base_token.value)
            other_word = cast(str, cast(Token, match.pred).value)
            if self._should_skip(base_word) or self._should_skip(other_word):
                continue
            if word_freq(other_word, 'ru') > word_freq(base_word, 'ru'):
                replacements.append((base_token, other_word))

        # Splice from the end of the text towards its start so that the
        # character offsets of tokens still to be processed stay valid.
        # Sorting explicitly avoids depending on the order in which the
        # alignment returns its matches.
        replacements.sort(key=lambda item: item[0].start_pos, reverse=True)

        replaced_text = base_text.text
        for token, replacement in replacements:
            replaced_text = (
                replaced_text[:token.start_pos]
                + replacement
                + replaced_text[token.end_pos:]
            )
        return replaced_text