# Source code for asr_eval.align.parsing

from __future__ import annotations
from collections import Counter
from collections.abc import Iterable
from dataclasses import dataclass
from functools import lru_cache
from itertools import chain
import re
import string
from typing import Callable

from asr_eval.align.transcription import (
    Wildcard,
    Transcription,
    SingleVariantTranscription,
    Token,
    MultiVariantBlock,
)


# Public API of this module; everything else (helpers prefixed with "_",
# private module functions) is internal.
__all__ = [
    "Parser",
    "DEFAULT_PARSER",
    "PUNCT",
]


# re.escape() is applied so the set can be interpolated into regular
# expressions (e.g. the default Parser.tokenizing pattern, where it appears
# inside a character class) without any character acting as a regex operator.
PUNCT = re.escape(r""".,!?:;…-‑–—'"‘“”«»()[]{}""")
r"""
A default set of punctuation characters to exclude from words
:code:`.,!?:;…-‑–—'"‘“”«»()[]{}`. Note that this does not affect parsing
control characters "{", "|", "}" in multivariant syntax. To override,
create a :class:`~asr_eval.align.parsing.Parser` with custom
:attr:`~asr_eval.align.parsing.Parser.tokenizing` field.

:meta hide-value:
"""


@dataclass
class Parser:
    r"""Parses into words and (optionally) normalizes prediction or annotation.

    Performs the following:

    1. Preprocesses the whole text if
       :attr:`~asr_eval.align.parsing.Parser.preprocessing` is set. This stage
       is suitable for various normalization methods, if they are used, such
       as numerals-to-digits normalizers or filler words removers.
    2. If :meth:`~asr_eval.align.parsing.Parser.parse_transcription` is
       called, processes multivariant syntax.
    3. Splits all the text blocks into words with a regexp stored in the
       :attr:`~asr_eval.align.parsing.Parser.tokenizing` attribute.
    4. Postprocesses each word if
       :attr:`~asr_eval.align.parsing.Parser.postprocessing` is set. This
       stage is suitable for lowercase conversion.

    A :const:`~asr_eval.align.parsing.DEFAULT_PARSER` is an instance of the
    Parser with default parameters.

    Example:

    >>> from asr_eval.align.parsing import DEFAULT_PARSER  # same as Parser()
    >>> text = 'Hi there {fouth|4|t-th} {eh} <*>'
    >>> parsed = DEFAULT_PARSER.parse_transcription(text)
    >>> print(parsed.blocks)  # doctest: +NORMALIZE_WHITESPACE
    (Token(hi), Token(there),
    MultiVariantBlock([Token(fouth)], [Token(4)], [Token(t), Token(th)]),
    MultiVariantBlock([Token(eh)], []),
    Token(Wildcard()))
    >>> from dataclasses import asdict
    >>> asdict(parsed.blocks[0])  # doctest: +NORMALIZE_WHITESPACE
    {'value': 'hi', 'uid': 'id0', 'start_pos': 0, 'end_pos': 2,
    'start_time': nan, 'end_time': nan, 'attrs': {}, 'flags': set()}
    >>> print(parsed.colorize())  # doctest: +SKIP

    .. raw:: html

        <style>.y {background-color: #e0e841;} .g {background-color: #41e8a8;}
        .b {background-color: #41c7e8 ;}</style>
        <span style="white-space='pre'; font-family: 'Consolas', 'Ubuntu Mono',
        'Monaco', monospace">
        <span class="y">Hi</span> <span class="g">there</span>
        {<span class="b">fourth</span>|<span class="y">4</span>|<span
        class="g">4</span>-<span class="b">th</span>}
        {<span class="y">eh</span>} <span class="g"><*></span></span>

    Note:

    1. Why not just :code:`nltk.word_tokenize`? In *asr_eval* words keep
       references to their positions in the original text, which
       :code:`word_tokenize` does not support.
    2. By making a Parser with :code:`tokenizing=r'\\w|\\s|[^\\w\\s{PUNCT}]'`
       you can parse strings into characters, excluding punctuation. In this
       case, :class:`~asr_eval.align.alignment.Alignment` will calculate CER
       (character error rate) instead of WER.
    3. You can create named parsers in :mod:`asr_eval.bench.parsers`.
    4. When labeling a dataset, the annotator should be aware of the
       tokenization scheme. For example, if :code:`3/4$` is tokenized as a
       single word, then :code:`3/4$` and :code:`3 / 4 $` (with spaces) are
       different options, and both should be included in a multivariant
       block. See :doc:`/guide_alignment_wer` for details.

    Words (and, in general, tokens) can have attributes
    (:attr:`~asr_eval.align.transcription.Token.attrs`) and flags
    (:attr:`~asr_eval.align.transcription.Token.flags`). They are written in
    brackets, separated by comma from each other, and by "!" from the text.
    For example, the following syntax adds flag "abc" and attribute "w" with
    value "10" to the words bar, baz, qux:

    Example:

    >>> from asr_eval.align.parsing import DEFAULT_PARSER
    >>> text = '"Foo [abc,w=10! bar {baz} qux ]"'
    >>> transcription = DEFAULT_PARSER.parse_transcription(text)
    >>> for block in transcription.blocks:  # doctest: +SKIP
    ...     print(block)
    Token(foo)
    Token(bar, attrs:w=10, flags:abc)
    MultiVariantBlock([Token(baz, attrs:w=10, flags:abc)], [])
    Token(qux, attrs:w=10, flags:abc)
    """

    tokenizing: str = rf'\w+|[^\w\s{PUNCT}]+'
    r"""A regexp to extract word, by default :code:`\\w+|[^\\w\\s{PUNCT}]+`,
    where :const:`~asr_eval.align.parsing.PUNCT` are punctuation characters.

    :meta hide-value:
    """

    preprocessing: Callable[[str], str] = lambda text: text
    """A text preprocessing method set as :code:`Callable[[str], str]`, by
    default does nothing.

    Is suitable for text-to-text operations such as normalizers or filler
    word removers. Note that after parsing the
    :attr:`~asr_eval.align.transcription.Transcription.text` field in
    :class:`~asr_eval.align.transcription.Transcription` will contain the
    preprocessed version, and the original version will be gone.

    Example:

    >>> from asr_eval.align.parsing import Parser
    >>> import re
    >>> def filler_remover(text: str) -> str:
    ...     for word in 'eh', 'oh', 'umm':
    ...         text = re.sub(word, '', text, flags=re.IGNORECASE)
    ...     return text
    >>> parser = Parser(preprocessing=filler_remover)
    >>> parsed = parser.parse_transcription('Umm eh of course')
    >>> print(parsed.text, parsed.blocks)
     of course [Token(of), Token(course)]

    See more examples in :mod:`asr_eval.bench.parsers`.

    :meta hide-value:
    """

    postprocessing: Callable[[str], str] = (
        lambda text: text.lower().replace('ё', 'е')
    )
    """
    A word postprocessing method set as :code:`Callable[[str], str]`, by
    default performs lowercase and diacritic conversion:

    .. code-block:: python

        postprocessing=lambda text: text.lower().replace('ё', 'е')

    Will only affect the :attr:`~asr_eval.align.transcription.Token.value`
    field in :class:`~asr_eval.align.transcription.Token`. This is useful to
    match lowercase words, while tracking their positions in the original
    :attr:`~asr_eval.align.transcription.Transcription.text` with
    capitalization and punctuation.

    :meta hide-value:
    """

    def parse_single_variant_transcription(
        self, text: str
    ) -> SingleVariantTranscription:
        """Parses a text without multivariant blocks.

        In general, one needs this method for typing purposes only, because
        :meth:`~asr_eval.align.parsing.Parser.parse_transcription` supports
        both multivariant and single-variant transcriptions.
        """
        text = self.preprocessing(text)
        # strip "[attrs! ...]" syntax first; attrs are re-attached to tokens
        # below, once token positions in the trimmed text are known
        text, attrs_per_token = _extract_and_trim_attrs(text)
        tokens = self._split_text_into_tokens(text)
        result = SingleVariantTranscription(text, tuple(tokens))
        _assign_attrs_inplace(result, attrs_per_token)
        # assign sequential uids ("id0", "id1", ...) to all tokens
        for i, t in enumerate(result.list_all_tokens()):
            t.uid = 'id' + str(i)
        return result

    # No type annotation on purpose: dataclass treats unannotated class
    # attributes as plain class-level constants, not as fields.
    _MULTIVARIANT_PATTERN = re.compile(
        r'({[^{}]*?})'  # multi variant
        '|'
        r'(?<=})([^{}]+?)(?={)'  # single variant
    )
    # We could also parse multivariant strings with pyparsing:
    # from pyparsing import CharsNotIn, OneOrMore, Suppress as S, Group,
    #     Empty, ZeroOrMore
    # WORDS = CharsNotIn('{|}\n')('words')
    # OPTION = Group(WORDS | Empty())('option')
    # MULTI = Group(S('{') + OPTION + OneOrMore(S('|') + OPTION) \
    #     + S('}'))('multi')
    # MV_STRING = ZeroOrMore(MULTI | WORDS)
    # results = MV_STRING.parse_string('{a|b} ! {1|2 3|} x y {3|4}',
    #     parse_all=True)
    # print(results.as_list())
    # however, this is not obvious for ones who are not familiar with
    # pyparsing, and also gives uninformative parsing errors

    def parse_transcription(self, text: str) -> Transcription:
        """Parses a text possibly containing multivariant blocks.

        See example in the class docstring.
        """
        text = self.preprocessing(text)
        text, attrs_per_token = _extract_and_trim_attrs(text)
        blocks: list[Token | MultiVariantBlock] = []
        # Sentinels '}' / '{' are added so that the "single variant"
        # alternative of _MULTIVARIANT_PATTERN (which needs a '}' behind and
        # a '{' ahead) also matches leading/trailing plain text; all match
        # positions below are shifted by -1 to compensate.
        for match in re.finditer(self._MULTIVARIANT_PATTERN, '}' + text + '{'):
            text_part = match.group()
            start = match.start() - 1  # account for '}' (see in re.finditer)
            end = match.end() - 1  # account for '}' (see in re.finditer)
            if text_part.startswith('{'):
                # a multivariant block must be surrounded by whitespace
                if start > 0:
                    assert (c := text[start - 1]) in string.whitespace, (
                        f'put a space before a multivariant block, got "{c}"'
                    )
                if end < len(text):
                    assert (c := text[end]) in string.whitespace, (
                        f'put a space after a multivariant block, got "{c}"'
                    )
                # options_raw: (option text, start pos)
                options_raw: list[tuple[str, int]] = []
                # appending '|' lets one regexp iteration handle the last
                # option uniformly
                for option_match in re.finditer(
                    r'([^\|]*)\|', text_part[1:-1] + '|'
                ):
                    option_text = option_match.group(1)
                    start_pos = start + option_match.start() + 1
                    if option_text.strip().startswith('~'):
                        # lexically wrong but acceptable form
                        option_text = option_text.strip()[1:]
                    if (match2 := re.match(
                        r'([^\<]+)\<(\w+)\>', option_text.strip()
                    )) is not None:
                        # forms like Facebook<е>: expanded into two options,
                        # "Facebook" and "Facebook-е"
                        # TODO: ambiguous: need to add empty option to
                        # {Facebook<е>} ??
                        # TODO: handle this in single-variant blocks
                        base, suffix = match2.groups()
                        options_raw.append((f'{base}', start_pos))
                        options_raw.append((f'{base}-{suffix}', start_pos))
                    else:
                        options_raw.append((option_text, start_pos))
                options: list[list[Token]] = []
                for option_text, start_pos in options_raw:
                    option_tokens = self._split_text_into_tokens(option_text)
                    # token positions are relative to the option text; shift
                    # them to positions in the whole transcription text
                    _shift_tokens_inplace(option_tokens, start_pos)
                    options.append(option_tokens)
                if len(options) == 1:
                    # "{eh}" means "eh or nothing": add the empty option
                    assert len(options[0]), 'empty multivariant block'
                    options.append([])
                blocks.append(MultiVariantBlock(
                    options=options,
                    start_pos=start,
                    end_pos=end,
                ))
            else:
                # plain text between multivariant blocks
                new_tokens = self._split_text_into_tokens(text_part)
                _shift_tokens_inplace(new_tokens, shift=start)
                blocks += new_tokens
        result = Transcription(text, tuple(blocks))
        _assign_attrs_inplace(result, attrs_per_token)
        # sequential uids for tokens ("id0", ...) and multivariant blocks
        # ("mvid0", ...)
        for i, t in enumerate(result.list_all_tokens()):
            t.uid = 'id' + str(i)
        i = 0
        for block in result.blocks:
            if isinstance(block, MultiVariantBlock):
                block.uid = 'mvid' + str(i)
                i += 1
        return result

    def _split_text_into_tokens(self, text: str) -> list[Token]:
        """Finds words in the text and returns them as a list of Token."""
        tokens = list(_regexp_split_text_into_tokens(
            text, {'word': self.tokenizing}
        ))
        # postprocessing applies to ordinary words only, never to wildcards
        for token in tokens:
            if not isinstance(token.value, Wildcard):
                token.value = self.postprocessing(token.value)
        return tokens
DEFAULT_PARSER = Parser()
"""An instance of :class:`~asr_eval.align.parsing.Parser` with default
parameters.

:meta hide-value:
"""


def _shift_tokens_inplace(tokens: list[Token], shift: int = 0):
    # Shifts token character positions by `shift`; used to convert positions
    # relative to a text fragment into positions in the whole text.
    for t in tokens:
        t.start_pos += shift
        t.end_pos += shift


@lru_cache(maxsize=100)
def _compile_regexp(patterns: tuple[tuple[str, str], ...]) -> re.Pattern[str]:
    # Builds one alternation of named groups from (name, subpattern) pairs.
    # Takes a tuple (not a dict) so the argument is hashable for lru_cache.
    pattern = '|'.join(
        f'(?P<{name}>{subpattern})' for name, subpattern in patterns
    )
    return re.compile(pattern, re.MULTILINE|re.DOTALL|re.UNICODE)


def _regexp_split_text_into_tokens(
    text: str, patterns: dict[str, str]
) -> Iterable[Token]:
    """Searches sequentially for any of the given patterns.

    For each match returns a :class:`~asr_eval.align.transcription.Token`.
    """
    # this is overcomplicated, TODO simplify?
    pattern = _compile_regexp(tuple(patterns.items()))
    for match in re.finditer(pattern, text):
        # exactly one named group matches per position in the alternation
        found_groups = [
            (name, substr) for name, substr in match.groupdict().items()
            if substr is not None
        ]
        assert len(found_groups) == 1
        name, word = found_groups[0]
        assert name in patterns
        yield Token(
            # the wildcard marker becomes a Wildcard value, not a string
            value=Wildcard() if word == Wildcard._SYMBOL else word,  # pyright: ignore[reportPrivateUsage]
            start_pos=match.start(),
            end_pos=match.end(),
        )


@dataclass(slots=True)
class _MarkedAttr:
    # A single "[marks! ...]" annotation: `value` is the raw marks string,
    # (start, end) is the span of the annotated text AFTER the brackets and
    # marks have been cut out.
    value: str
    start: int
    end: int


_ATTRS_PATTERN = re.compile(
    r'(\[)'  # starting [
    r'([^!\]]*)'  # marks
    r'(!)'  # ! symbol
    r'([^\]]*)'  # transcription
    r'(\])'  # ending ]
)


def _extract_and_trim_attrs(text: str) -> tuple[str, list[list[str]]]:
    # we start with this:     X = "one two [w=10!Kafka AWS] three"
    # and convert to this:    Y = "one two Kafka AWS three"
    # while extracting this:  [("w=10", 8, 17)]
    # where (8, 17) is the position of "Kafka AWS" in Y

    # span of value, start symbol, end symbol
    attrs: list[_MarkedAttr] = []
    # iterating while modifying the string
    # TODO speedup this logic for long texts
    while (match := re.search(_ATTRS_PATTERN, text)):
        # example: [digits,a=b!три четыре]   ("три четыре" is sample Russian)
        # here: "digits,a=b" - flags for "три четыре"

        # extract groups (_ATTRS_PATTERN has 5 groups)
        _, attr, _, _text, _ = match.groups()
        assert attr, f'Attrs should not be empty: {match.group()}'
        # extract group positions
        s1, _s2, s3, _s4, s5 = [match.span(i) for i in range(1, 6)]
        if len(attrs):
            # earlier matches were already cut out, so a new match must start
            # at or after the previous annotated span
            assert s1[0] >= attrs[-1].end, 'Overlapping attrs'
        # spans to cut from text ("[digits,a=b!" and "]" in our example)
        start1, end1 = s1[0], s3[1]
        start2, end2 = s5
        # cut spans from text
        text = text[:start1] + text[end1:start2] + text[end2:]
        # save the attrs; end position is shifted left by the cut prefix
        attr = _MarkedAttr(attr, start1, start1 + start2 - end1)
        attrs.append(attr)
        # sanity check: the recorded span addresses the annotated text
        assert text[attr.start:attr.end] == _text
    # per-character attr lists for the trimmed text; a token later receives
    # an attr only if every one of its characters is covered by it
    attrs_per_token: list[list[str]] = [[] for _ in range(len(text))]
    for attr in attrs:
        for pos in range(attr.start, attr.end):
            for value in attr.value.split(','):
                attrs_per_token[pos].append(value)
    return text, attrs_per_token


def _assign_attrs_inplace(
    transcription: Transcription, attrs_per_token: list[list[str]]
):
    # Transfers per-character attrs onto tokens: "k=v" entries go to
    # token.attrs, bare entries go to token.flags. An attr must cover the
    # token's full character span, otherwise this raises.
    for token in transcription.list_all_tokens():
        attrs_for_token = Counter(chain(*[
            attrs_per_token[pos]
            for pos in range(token.start_pos, token.end_pos)
        ]))
        for attr, count in attrs_for_token.most_common():
            assert count == token.end_pos - token.start_pos, (
                f'Attr {attr} half-covered a word {token.to_text()}'
            )
            if '=' in attr:
                k, v = attr.split('=', 1)
                token.attrs[k] = v
            else:
                token.flags.add(attr)