# Source code for asr_eval.models.vikhr_wrapper
from __future__ import annotations

from typing import TYPE_CHECKING, override

if TYPE_CHECKING:
    # Heavy transformers types are needed for static analysis only; keeping
    # them inside the guard avoids importing transformers at module load time.
    from transformers.generation.utils import GenerationMixin
    from transformers.models.whisper.feature_extraction_whisper import (
        WhisperFeatureExtractor,
    )

from asr_eval.models.base.interfaces import Transcriber
from asr_eval.utils.types import FLOATS

__all__ = [
    'VikhrBorealisWrapper',
]
class VikhrBorealisWrapper(Transcriber):
    """A Vikhr Borealis wrapper.

    Loading a model takes a long time, around 2 min.

    Installation: see :doc:`/guide_installation` page.
    """

    # Autoregressive model loaded via trust_remote_code (custom Borealis class).
    model: GenerationMixin
    # tokenizer: Qwen2TokenizerFast
    # Whisper-style mel feature extractor shipped with the checkpoint.
    extractor: WhisperFeatureExtractor
    # Torch device string the model and inputs live on.
    device: str

    def __init__(self, device: str = 'cuda'):
        """Load the Borealis model and feature extractor onto *device*.

        :param device: torch device string (e.g. ``'cuda'``, ``'cpu'``);
            defaults to ``'cuda'`` for backward compatibility with the
            previous hard-coded behavior.
        """
        from transformers import AutoFeatureExtractor, AutoModelForCausalLM

        model_name = 'Vikhrmodels/Borealis'
        # NOTE(review): trust_remote_code executes code from the hub repo —
        # a supply-chain risk; consider pinning a revision.
        self.model = AutoModelForCausalLM.from_pretrained(  # type: ignore
            model_name, trust_remote_code=True
        )
        # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.extractor = AutoFeatureExtractor.from_pretrained(model_name)  # type: ignore
        self.device = device
        self.model.eval()  # type: ignore
        self.model = self.model.to(device)  # type: ignore

    @override
    def transcribe(self, waveform: FLOATS) -> str:
        """Transcribe a waveform and return the decoded text.

        :param waveform: 1-D float audio samples; the extractor is invoked
            with ``sampling_rate=16_000``, so input is assumed to be 16 kHz
            mono — callers must resample beforehand.
        """
        import torch

        # Pad/trim to a fixed 30-second window (16_000 Hz * 30 s samples),
        # the Whisper-style extractor's expected input length.
        proc = self.extractor(
            waveform,
            sampling_rate=16_000,
            padding='max_length',
            max_length=16_000 * 30,
            return_attention_mask=True,
            return_tensors='pt',
        )
        with torch.inference_mode():
            # The custom Borealis generate() takes mel features directly.
            # NOTE(review): do_sample=True makes transcription
            # non-deterministic (temperature=0.2 keeps it near-greedy) —
            # confirm sampling is intended for an eval wrapper. The declared
            # `-> str` return presumes the remote generate() decodes to text;
            # verify against the model's remote code.
            return self.model.generate(  # type: ignore
                mel=proc.input_features.squeeze(0).to(self.device),  # type: ignore
                att_mask=proc.attention_mask.squeeze(0).to(self.device),  # type: ignore
                max_new_tokens=350,
                do_sample=True,
                top_p=0.9,
                top_k=50,
                temperature=0.2,
            )