Source code for asr_eval.models.gigaam_wrapper

from __future__ import annotations
from abc import ABC, abstractmethod
import shutil
from typing import TYPE_CHECKING, Literal, cast, override
import warnings

if TYPE_CHECKING:
    from gigaam.model import GigaAMASR, GigaAM, GigaAMEmo
    import torch

from asr_eval.models.base.interfaces import CTC, Transcriber
from asr_eval.utils.types import FLOATS


__all__ = [
    'GigaAMShortformBase',
    'GigaAMShortformRNNT',
    'GigaAMShortformCTC',
]


SAMPLING_RATE = 16000
FREQ = 25  # GigaAM2 encoder outputs per second


class GigaAMShortformBase(Transcriber, ABC):
    '''
    An abstract class for a GigaAM model, either CTC or RNNT.

    Implementations:

    - :class:`~asr_eval.models.gigaam_wrapper.GigaAMShortformCTC`
    - :class:`~asr_eval.models.gigaam_wrapper.GigaAMShortformRNNT`
    '''

    @abstractmethod
    def _get_model(self) -> GigaAMASR:
        """Return the wrapped GigaAMASR model."""

    def __init__(self):
        import gigaam
        from gigaam.model import SAMPLE_RATE, LONGFORM_THRESHOLD
        from gigaam.decoding import CTCGreedyDecoding

        if '_MODEL_HASHES' not in dir(gigaam):
            raise RuntimeError(
                'You seem to have installed an old gigaam version from PyPI'
                ', update from https://github.com/salute-developers/GigaAM'
            )
        self.SAMPLE_RATE = SAMPLE_RATE
        self.LONGFORM_THRESHOLD = LONGFORM_THRESHOLD
        self.CTCGreedyDecoding = CTCGreedyDecoding

    @override
    def transcribe(self, waveform: FLOATS) -> str:
        import torch

        with torch.inference_mode():
            # a forward + decoding pipeline suitable for both CTC and RNNT
            model = self._get_model()
            inputs = (
                torch.tensor(waveform)
                .to(model._device)  # pyright:ignore[reportPrivateUsage]
                .to(model._dtype)  # pyright:ignore[reportPrivateUsage]
                .unsqueeze(0)
            )
            length = torch.full([1], inputs.shape[-1], device=model._device)  # pyright:ignore[reportPrivateUsage]
            encoded, encoded_len = model.forward(inputs, length)
            return model.decoding.decode(model.head, encoded, encoded_len)[0]


def _gigaam_load_model(
    model_name: str,
    fp16_encoder: bool = True,
    use_flash: bool | None = False,
    device: str | torch.device | None = None,
    download_root: str | None = None,
) -> GigaAM | GigaAMEmo | GigaAMASR:
    """Same as gigaam.load_model, but removes half-downloaded checkpoints
    and retries once on failure.
    """
    import gigaam

    try:
        return gigaam.load_model(
            model_name=model_name,
            fp16_encoder=fp16_encoder,
            use_flash=use_flash,
            device=device,
            download_root=download_root,
        )
    except RuntimeError as e:
        print(f'GigaAM failed to load: {e}')
        directory = (
            download_root
            if download_root is not None
            else gigaam._CACHE_DIR  # pyright: ignore[reportPrivateUsage]
        )
        print(f'Removing GigaAM cache dir {directory} and trying again')
        shutil.rmtree(directory)
        return gigaam.load_model(
            model_name=model_name,
            fp16_encoder=fp16_encoder,
            use_flash=use_flash,
            device=device,
            download_root=download_root,
        )


class GigaAMShortformRNNT(GigaAMShortformBase):
    """A GigaAM RNNT model.

    Supports different versions (see :code:`version` parameter):
    "v2", "v3", "v3_e2e".

    Installation: see :doc:`/guide_installation` page.
    """

    def __init__(
        self,
        version: Literal['v2', 'v3', 'v3_e2e'],
        device: str | torch.device = 'cuda',
        fp16: bool = False,
    ):
        from gigaam.model import GigaAMASR  # type: ignore

        super().__init__()
        self._model = cast(GigaAMASR, _gigaam_load_model(
            f'{version}_rnnt', fp16_encoder=fp16, device=device
        ))
        self._model.eval()

    @override
    def _get_model(self) -> GigaAMASR:
        return self._model


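# Illustrative usage of the RNNT wrapper (a sketch, not part of the module):
# ``waveform`` is assumed to be a mono float array sampled at
# SAMPLING_RATE == 16_000, for example:
#
#     model = GigaAMShortformRNNT(version='v3', device='cuda')
#     text = model.transcribe(waveform)

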
class GigaAMShortformCTC(GigaAMShortformBase, CTC):
    """A GigaAM CTC model.

    Supports different versions (see :code:`version` parameter):
    "v2", "v3", "v3_e2e".

    Installation: see :doc:`/guide_installation` page.
    """

    def __init__(
        self,
        version: Literal['v2', 'v3', 'v3_e2e'],
        device: str | torch.device = 'cuda',
        fp16: bool = False,
    ):
        from gigaam.model import GigaAMASR  # type: ignore
        from gigaam.decoding import Tokenizer
        from sentencepiece import SentencePieceProcessor

        super().__init__()
        self._model = cast(GigaAMASR, _gigaam_load_model(
            f'{version}_ctc', fp16_encoder=fp16, device=device
        ))
        self._model.eval()

        tokenizer = cast(Tokenizer, self._model.decoding.tokenizer)
        if tokenizer.charwise:
            # v2_ctc and v3_ctc have a character vocabulary
            vocab = self._model.decoding.tokenizer.vocab.copy()
            vocab.insert(self.blank_id, '')
        else:
            # v3_e2e_ctc stores its vocabulary in a sentencepiece tokenizer
            sp = cast(
                SentencePieceProcessor, self._model.decoding.tokenizer.model
            )
            vocab_size = cast(int, sp.GetPieceSize())
            # check that the tokenizer does not contain byte fragments,
            # so we can equivalently decode from the vocab as `list[str]`
            assert all(
                '\ufffd' not in sp.DecodeIds([id])  # type: ignore
                for id in range(sp.GetPieceSize())  # type: ignore
            )
            vocab = [
                # replace the special char that marks a space
                cast(str, sp.IdToPiece(id)).replace('▁', ' ')  # type: ignore
                for id in range(vocab_size)
                # if not sp.IsControl(id) and not sp.IsUnknown(id)
            ]
        self._vocab = tuple(vocab)

    @override
    def _get_model(self) -> GigaAMASR:
        return self._model

    @override
    def transcribe(self, waveform: FLOATS) -> str:
        # both base classes, GigaAMShortformBase and CTC, are in the MRO;
        # the two-argument super() starts the lookup *after*
        # GigaAMShortformBase, so the call dispatches to the .transcribe()
        # implementation that follows it in the MRO
        return super(GigaAMShortformBase, self).transcribe(waveform)

    @property
    @override
    def blank_id(self) -> int:
        return self._model.decoding.blank_id

    @property
    @override
    def tick_size(self) -> float:
        return 1 / FREQ

    @property
    @override
    def vocab(self) -> tuple[str, ...]:
        return self._vocab

    @override
    def ctc_log_probs(self, waveforms: list[FLOATS]) -> list[FLOATS]:
        import torch
        from torch.nn.utils.rnn import pad_sequence

        with torch.inference_mode():
            # Sampling rate should be equal to
            # gigaam.preprocess.SAMPLE_RATE == 16_000.
            assert isinstance(self._model.decoding, self.CTCGreedyDecoding)
            for waveform in waveforms:
                if len(waveform) / self.SAMPLE_RATE > self.LONGFORM_THRESHOLD:
                    warnings.warn(
                        'audio is too long,'
                        ' GigaAMASR.transcribe() would throw an error',
                        RuntimeWarning
                    )
            waveform_tensors = [
                torch.tensor(w, dtype=self._model._dtype)  # pyright: ignore[reportPrivateUsage]
                .to(self._model._device)  # pyright: ignore[reportPrivateUsage]
                for w in waveforms
            ]
            lengths = (
                torch.tensor([len(w) for w in waveforms])
                .to(self._model._device)  # pyright: ignore[reportPrivateUsage]
            )
            waveform_tensors_padded = pad_sequence(
                waveform_tensors, batch_first=True, padding_value=0,
            )
            encoded, encoded_len = self._model.forward(
                waveform_tensors_padded, lengths
            )
            log_probs = cast(
                torch.Tensor, self._model.head(encoder_output=encoded)
            )
            # exp(log_probs) sums to 1 over the vocab axis
            return [
                _log_probs[:length].cpu().numpy()  # type: ignore
                for _log_probs, length in zip(log_probs, encoded_len)
            ]


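# A minimal usage sketch (illustration only, not part of the wrapper itself):
# reading 'example.wav' with ``soundfile`` is an assumption; any 16 kHz mono
# float array works as input.
if __name__ == '__main__':
    import numpy as np
    import soundfile as sf

    audio, sr = sf.read('example.wav', dtype='float32')
    assert sr == SAMPLING_RATE, 'GigaAM expects 16 kHz audio'

    model = GigaAMShortformCTC(version='v2', device='cuda')
    print(model.transcribe(audio))

    # one (frames, vocab) matrix per input; exp() of each frame sums to 1,
    # and frames advance every model.tick_size == 1 / FREQ seconds
    (log_probs,) = model.ctc_log_probs([audio])
    print(log_probs.shape, np.exp(log_probs).sum(axis=-1))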