Source code for asr_eval.models.flamingo_wrapper
import sys
from typing import Literal, override
from asr_eval.utils.audio_ops import waveform_as_file
from asr_eval.utils.types import FLOATS
from asr_eval.models.base.interfaces import Transcriber
__all__ = ['FlamingoWrapper']
[docs]
class FlamingoWrapper(Transcriber):
    '''
    A transcriber backed by NVIDIA Audio Flamingo 3
    (``nvidia/audio-flamingo-3``). Not working anymore, TODO fix.

    On construction, downloads the ``nvidia/audio-flamingo-3`` Hugging Face
    Space (whose code provides the ``llava`` package used to run the model)
    and the model weights, then loads the model onto CUDA.

    Installation: see :doc:`/guide_installation` page.

    Authors: Dmitry Ezhov & Oleg Sedukhin
    '''

    # Text prompt sent to the model together with the audio, per language.
    _PROMPTS: dict[str, str] = {
        'en': 'Transcribe the audio.',
        'ru': (
            'Транскрибируй аудио на русском языке.'
            ' Текст должен быть на русском языке.'
        ),
    }

    # Maximum accepted waveform length in samples: 30 seconds at 16 kHz.
    # The processor's feature extractor trims audio to 30 seconds, so longer
    # input would otherwise be silently truncated.
    _MAX_SAMPLES = 16_000 * 30

    def __init__(self, lang: Literal['en', 'ru'] = 'ru'):
        '''
        Download and load Audio Flamingo 3.

        Args:
            lang: Language of the transcription prompt, ``'en'`` or ``'ru'``.

        Raises:
            ValueError: If ``lang`` is not supported. This is checked before
                anything is downloaded.
        '''
        # Fail fast, before the (large) downloads below.
        if lang not in self._PROMPTS:
            raise ValueError(
                f'Unsupported lang {lang!r}, expected one of '
                f'{sorted(self._PROMPTS)}'
            )

        from huggingface_hub import snapshot_download  # type: ignore

        # The Space repository ships the ``llava`` package; make it importable
        # only for the duration of the import.
        code_path = snapshot_download(
            repo_id='nvidia/audio-flamingo-3', repo_type='space'
        )
        sys.path.append(code_path)
        try:
            import llava  # type: ignore
        finally:
            # Restore sys.path even if the import fails. Remove the entry by
            # value: the import itself may have appended to sys.path, in
            # which case pop(-1) would drop the wrong entry.
            sys.path.remove(code_path)
        self.llava_module = llava

        model_path = snapshot_download(repo_id='nvidia/audio-flamingo-3')
        self.model = llava.load(model_path, model_base=None).cuda()  # type: ignore
        self.lang: Literal['en', 'ru'] = lang

    @override
    def transcribe(self, waveform: FLOATS) -> str:
        '''
        Transcribe one waveform with the language-specific prompt.

        Args:
            waveform: Audio samples; ``len(waveform)`` is treated as the
                sample count, presumably a 1-D array at 16 kHz (TODO confirm).

        Returns:
            The model's generated text.

        Raises:
            ValueError: If the waveform is longer than 30 seconds.
        '''
        # The processor calls self.feature_extractor(audio, ...), which trims
        # audio to 30 seconds. Raise rather than assert so the check also
        # runs under ``python -O``.
        if len(waveform) > self._MAX_SAMPLES:
            raise ValueError('Audio should be <= 30 seconds length')
        prompt = self._PROMPTS[self.lang]
        with waveform_as_file(waveform) as audio_path:
            # Sound loading requires the soundfile lib. Generation stays
            # inside the ``with`` to keep the file until it is done.
            sound = self.llava_module.Sound(str(audio_path))  # type: ignore
            return self.model.generate_content(  # type: ignore
                [sound, prompt],
                generation_config=self.model.default_generation_config,  # type: ignore
            )