asr_eval.ctc

Utils for CTC models.

asr_eval.ctc.ctc_mapping(symbols, blank)[source]

Performs a CTC mapping: first collapses each run of consecutive identical symbols into one, then removes blank tokens.

Example

>>> x = list('_________дджжой   иссто__ч_ни__ки_________   _иссто_ри__и')
>>> ctc_mapping(x, blank='_') == list('джой источники истории')
True
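
The two steps can be sketched as follows. This is a minimal illustration, not the library's implementation; the name ctc_mapping_sketch is hypothetical, and the symbols are assumed to be comparable with ==.

```python
from itertools import groupby


def ctc_mapping_sketch(symbols, blank):
    # Step 1: collapse each run of consecutive identical symbols into one.
    collapsed = [symbol for symbol, _ in groupby(symbols)]
    # Step 2: drop the blank symbol.
    return [symbol for symbol in collapsed if symbol != blank]
```

The order of the steps matters: in 'l_l' the blank separates two distinct occurrences of 'l', so both survive, whereas 'll' collapses to a single 'l'.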
asr_eval.ctc.forced_alignment.forced_alignment(log_probs, true_tokens, blank_id=0)[source]

Performs a forced alignment.

Returns the path with the highest cumulative probability among all paths that match the specified transcription.

Parameters:
  • log_probs (ndarray[tuple[int, ...], dtype[floating[Any]]]) – log probabilities from CTC model.

  • true_tokens (list[int] | ndarray[tuple[int, ...], dtype[integer[Any]]]) – a sequence of tokens for the ground truth transcription.

  • blank_id (int) – an index for <blank> CTC token.

Return type:

tuple[list[int], list[float], list[tuple[int, int]]]

Returns:

A tuple of three elements. The first is the token assigned to each frame. The second is the probability of that token at each frame. The third is a frame span (start_position, end_position) for each token in true_tokens.

Note

This will stop working with torchaudio>=2.9.0, see https://github.com/pytorch/audio/issues/3902 . It is possible to use forced_alignment_via_recursion() instead, but it may hit the recursion limit (Python's default limit is 1000, which corresponds to 20 seconds of audio at 50 frames per second). To use the custom implementation, set the environment variable FORCED_ALIGN_CUSTOM=1.
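
For illustration, forced alignment can also be computed iteratively with Viterbi dynamic programming, which avoids the recursion limit. The sketch below is not the library's code. The function name is hypothetical, log_probs is assumed to have shape (frames, vocabulary), and spans are assumed to be half-open (end_position exclusive) and to cover only the frames assigned to the token itself, not the surrounding blanks.

```python
import numpy as np


def viterbi_forced_alignment_sketch(log_probs, true_tokens, blank_id=0):
    tokens = [int(t) for t in true_tokens]
    n_frames = log_probs.shape[0]
    # Extended sequence: blank, t1, blank, t2, ..., blank.
    ext = [blank_id]
    for token in tokens:
        ext += [token, blank_id]
    n_states = len(ext)

    score = np.full((n_frames, n_states), -np.inf)
    back = np.zeros((n_frames, n_states), dtype=np.int64)
    # A path may start in the leading blank or in the first token.
    score[0, 0] = log_probs[0, ext[0]]
    if n_states > 1:
        score[0, 1] = log_probs[0, ext[1]]

    for t in range(1, n_frames):
        for s in range(n_states):
            # Allowed moves: stay, advance by one, or skip a blank
            # between two different tokens.
            candidates = [s]
            if s >= 1:
                candidates.append(s - 1)
            if s >= 2 and ext[s] != blank_id and ext[s] != ext[s - 2]:
                candidates.append(s - 2)
            best = max(candidates, key=lambda p: score[t - 1, p])
            score[t, s] = score[t - 1, best] + log_probs[t, ext[s]]
            back[t, s] = best

    # A path may end in the trailing blank or in the last token.
    s = n_states - 1
    if n_states > 1 and score[-1, n_states - 2] > score[-1, s]:
        s = n_states - 2
    if not np.isfinite(score[-1, s]):
        raise ValueError("not enough frames to align the transcription")

    # Backtrack to recover the best state for every frame.
    states = [0] * n_frames
    for t in range(n_frames - 1, -1, -1):
        states[t] = int(s)
        s = back[t, s]

    frame_tokens = [ext[s] for s in states]
    frame_probs = [float(np.exp(log_probs[t, ext[s]])) for t, s in enumerate(states)]
    spans = []
    for i in range(len(tokens)):
        frames = [t for t, s in enumerate(states) if s == 2 * i + 1]
        spans.append((frames[0], frames[-1] + 1))
    return frame_tokens, frame_probs, spans
```

The nested Python loops cost O(frames × tokens) and are slow for long audio; the sketch favors readability over speed.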

asr_eval.ctc.forced_alignment.forced_alignment_via_recursion(log_probs, true_tokens, blank_id=0)[source]

Performs forced alignment via a custom recursive algorithm.

Return type:

tuple[list[int], list[float]]

Parameters:
  • log_probs (ndarray[tuple[int, ...], dtype[floating[Any]]])

  • true_tokens (list[int] | ndarray[tuple[int, ...], dtype[integer[Any]]])

  • blank_id (int)

class asr_eval.ctc.lm.CTCDecoderWithLM(ctc_model, kenlm_path, unigrams=None, alpha=0.5, beta=1.0, beam_width=100, unk_score_offset=-10.0, lm_score_boundary=True, hotwords=None, hotword_weight=10.0, speedup_hotwords=True, beam_prune_logp=-10, token_min_logp=-5)[source]

Bases: TimedTranscriber

Performs joint decoding from CTC logits and KenLM language model.

Parameters:
  • ctc_model (CTC) – Any model with CTC interface.

  • kenlm_path (str | Path) – A KenLM model path, usually a .gz or .bin file.

  • unigrams (Collection[str] | None) – Passed into pyctcdecode.build_ctcdecoder.

  • alpha (float) – Weight for language model during shallow fusion.

  • beta (float) – Weight for length score adjustment during scoring.

  • beam_width (int) – Passed into pyctcdecode.build_ctcdecoder.

  • unk_score_offset (float) – Passed into pyctcdecode.build_ctcdecoder.

  • lm_score_boundary (bool) – Passed into pyctcdecode.build_ctcdecoder.

  • hotwords (list[str] | None) – A list of hotwords.

  • hotword_weight (float) – A score for hotwords.

  • speedup_hotwords (bool) – If True, will try to speed up decoding with hotwords and make them not case sensitive, see below.

  • beam_prune_logp (float) – Passed into decoder._decode_logits.

  • token_min_logp (float) – Passed into decoder._decode_logits.
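
As a rough illustration of how alpha and beta enter the beam score, the sketch below shows the usual shallow-fusion form. The function name and arguments are hypothetical, and the exact formula inside pyctcdecode (for example, the logarithm base used for the language model score) may differ.

```python
def fused_score_sketch(logit_score, lm_log_prob, n_words, alpha=0.5, beta=1.0):
    # Acoustic (CTC) score, plus the weighted language model score,
    # plus a per-word bonus that offsets the LM's bias toward short outputs.
    return logit_score + alpha * lm_log_prob + beta * n_words
```

With alpha=0 the language model is ignored; a larger beta favors hypotheses with more words.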

Note

For Vosk models, a warning is shown: “No known unigrams provided, decoding results might be a lot worse.” This is expected (according to Nikolay V. Shmyrev).

CTCDecoderWithLM may also show the warning “Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?” This is expected for wav2vec2, because .vocab() in Wav2vec2Wrapper may contain special tokens like “<s>”, which the model usually does not predict (they should have low logit scores).

If speedup_hotwords=True, the decoder tries to speed up hotword scoring. Otherwise it uses the default pyctcdecode implementation, which works as follows:

# create pattern to match full words
# sort by length to get longest possible match
# use lookahead and lookbehind to match on word boundary instead of '\b' to only match
# on space or bos/eos
match_ptn = re.compile(
    r"|".join(
        [
            r"(?<!\S)" + re.escape(s) + r"(?!\S)"
            for s in sorted(hotword_unigrams, key=len, reverse=True)
        ]
    )
)

score = self._weight * len(self._match_ptn.findall(text))

However, hotword_unigrams never contain spaces. To speed this up, with speedup_hotwords=True the scoring is replaced with the following:

hotwords: set[str]
score = self._weight * sum([word in self.hotwords for word in text.split()])

This should give equivalent results. Note that hotwords are case-insensitive when speedup_hotwords=True and case-sensitive otherwise.
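
The two scoring variants can be compared side by side with the standalone sketch below. The function names are hypothetical, the code is not taken from asr_eval, and case folding is left out for brevity. It assumes a non-empty hotword set without whitespace inside any hotword.

```python
import re


def regex_hotword_score(text, hotword_unigrams, weight):
    # Default approach: one alternation pattern, matched on whitespace boundaries.
    match_ptn = re.compile(
        r"|".join(
            r"(?<!\S)" + re.escape(s) + r"(?!\S)"
            for s in sorted(hotword_unigrams, key=len, reverse=True)
        )
    )
    return weight * len(match_ptn.findall(text))


def set_hotword_score(text, hotwords, weight):
    # Faster approach: split on whitespace and test set membership per word.
    return weight * sum(word in hotwords for word in text.split())
```

Both functions count whole words only, so "podcast" does not count as an occurrence of the hotword "pod".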

decode(waveform)[source]

Accepts a waveform, returns beams, sorted from the best to the worst.

Return type:

list[OutputBeam]

Parameters:

waveform (ndarray[tuple[int, ...], dtype[floating[Any]]])

decode_from_log_probs(log_probs)[source]

Accepts log probs from a CTC model, returns beams, sorted from the best to the worst.

Return type:

list[OutputBeam]

Parameters:

log_probs (ndarray[tuple[int, ...], dtype[floating[Any]]])

timed_transcribe(waveform)[source]

Transcribes a float32 waveform, typically normalized to the range from -1 to 1, into a list of texts with timings. The texts are typically joined with spaces, so leading or trailing spaces in each chunk are not required.

Return type:

list[TimedText]

Parameters:

waveform (ndarray[tuple[int, ...], dtype[floating[Any]]])

class asr_eval.ctc.lm.OutputBeam(text, last_lm_state, text_frames, logit_score, lm_score)[source]

Outputs of BeamSearchDecoderCTC.decode_beams as dataclass.

This is needed for version 0.5.0 of pyctcdecode, but not for the latest version from the repository. asr_eval uses version 0.5.0.

Parameters:
  • text (str)

  • last_lm_state (kenlm.State | list[kenlm.State])

  • text_frames (list[WordFrames])

  • logit_score (float)

  • lm_score (float)
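
A minimal sketch of the idea, assuming that decode_beams returns each beam as a plain tuple in the same field order as the parameters above. The class name BeamSketch, the loose field types, and the sample values are hypothetical.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class BeamSketch:
    text: str
    last_lm_state: Any  # kenlm.State or a list of them in the real class
    text_frames: list   # list[WordFrames] in the real class
    logit_score: float
    lm_score: float


# Made-up beam in the assumed tuple layout.
raw_beam = ("hello world", None, [("hello", (0, 5)), ("world", (6, 11))], -3.2, -4.1)
beam = BeamSketch(*raw_beam)
```

Named fields such as beam.text and beam.lm_score are less error-prone than positional access such as raw_beam[0] and raw_beam[4].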