From a378a54a5738bebebcfc91e1e005975af668daf2 Mon Sep 17 00:00:00 2001 From: Ritik Nandwal <48522685+nandwalritik@users.noreply.github.com> Date: Mon, 26 Aug 2024 20:37:01 +0530 Subject: [PATCH] Add changes for uroman package to handle non-Roman characters (#32404) * Add changes for uroman package to handle non-Roman characters * Update docs for uroman changes * Modifying error message to warning, for backward compatibility * Update instruction for user to install uroman * Update docs for uroman python version dependency and backward compatibility * Update warning message for python version compatibility with uroman * Refine docs --- docs/source/en/model_doc/vits.md | 25 +++++++++++++++++-- .../models/vits/tokenization_vits.py | 19 +++++++++----- src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 11 ++++++++ 4 files changed, 48 insertions(+), 8 deletions(-) diff --git a/docs/source/en/model_doc/vits.md b/docs/source/en/model_doc/vits.md index 73001d82ed..7a29586fc0 100644 --- a/docs/source/en/model_doc/vits.md +++ b/docs/source/en/model_doc/vits.md @@ -93,12 +93,33 @@ from transformers import VitsTokenizer tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng") print(tokenizer.is_uroman) ``` +If the `is_uroman` attribute is `True`, the tokenizer will automatically apply the `uroman` package to your text inputs. This requires the `uroman` Python package; if it is not already installed, install it with: +``` +pip install --upgrade uroman +``` +Note: The `uroman` Python package requires Python >= `3.10`. +You can use the tokenizer as usual without any additional preprocessing steps: +```python +import torch +from transformers import VitsTokenizer, VitsModel, set_seed +import os +import subprocess -If required, you should apply the uroman package to your text inputs **prior** to passing them to the `VitsTokenizer`, -since currently the tokenizer does not support performing the pre-processing itself. 
+tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-kor") +model = VitsModel.from_pretrained("facebook/mms-tts-kor") +text = "이봐 무슨 일이야" +inputs = tokenizer(text=text, return_tensors="pt") +set_seed(555) # make deterministic +with torch.no_grad(): + outputs = model(inputs["input_ids"]) + +waveform = outputs.waveform[0] +``` +If you cannot upgrade to Python >= `3.10`, you can instead use the `uroman` Perl package to convert the text inputs to the Roman alphabet yourself before passing them to the tokenizer. To do this, first clone the uroman repository to your local machine and set the bash variable `UROMAN` to the local path: + ```bash git clone https://github.com/isi-nlp/uroman.git cd uroman diff --git a/src/transformers/models/vits/tokenization_vits.py b/src/transformers/models/vits/tokenization_vits.py index 4c02857483..b4d8af7403 100644 --- a/src/transformers/models/vits/tokenization_vits.py +++ b/src/transformers/models/vits/tokenization_vits.py @@ -20,12 +20,14 @@ import re from typing import Any, Dict, List, Optional, Tuple, Union from ...tokenization_utils import PreTrainedTokenizer -from ...utils import is_phonemizer_available, logging +from ...utils import is_phonemizer_available, is_uroman_available, logging if is_phonemizer_available(): import phonemizer +if is_uroman_available(): + import uroman as ur logger = logging.get_logger(__name__) @@ -172,11 +174,16 @@ class VitsTokenizer(PreTrainedTokenizer): filtered_text = self._preprocess_char(text) if has_non_roman_characters(filtered_text) and self.is_uroman: - logger.warning( - "Text to the tokenizer contains non-Roman characters. Ensure the `uroman` Romanizer is " - "applied to the text prior to passing it to the tokenizer. See " - "`https://github.com/isi-nlp/uroman` for details." - ) + if not is_uroman_available(): + logger.warning( + "Text to the tokenizer contains non-Roman characters. 
To apply the `uroman` pre-processing " + "step automatically, ensure the `uroman` Romanizer is installed with: `pip install uroman` " + "Note `uroman` requires python version >= 3.10" + "Otherwise, apply the Romanizer manually as per the instructions: https://github.com/isi-nlp/uroman" + ) + else: + uroman = ur.Uroman() + filtered_text = uroman.romanize_string(filtered_text) if self.phonemize: if not is_phonemizer_available(): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 56f594da15..b1a1bb56cb 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -218,6 +218,7 @@ from .import_utils import ( is_torchdynamo_compiling, is_torchvision_available, is_training_run_on_sagemaker, + is_uroman_available, is_vision_available, requires_backends, torch_only_method, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 0c16cac0f0..c4bb1a64eb 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -142,6 +142,7 @@ _quanto_available = _is_package_available("quanto") _pandas_available = _is_package_available("pandas") _peft_available = _is_package_available("peft") _phonemizer_available = _is_package_available("phonemizer") +_uroman_available = _is_package_available("uroman") _psutil_available = _is_package_available("psutil") _py3nvml_available = _is_package_available("py3nvml") _pyctcdecode_available = _is_package_available("pyctcdecode") @@ -1107,6 +1108,10 @@ def is_phonemizer_available(): return _phonemizer_available +def is_uroman_available(): + return _uroman_available + + def torch_only_method(fn): def wrapper(*args, **kwargs): if not _torch_available: @@ -1383,6 +1388,11 @@ PHONEMIZER_IMPORT_ERROR = """ {0} requires the phonemizer library but it was not found in your environment. You can install it with pip: `pip install phonemizer`. 
Please note that you may need to restart your runtime after installation. """ +# docstyle-ignore +UROMAN_IMPORT_ERROR = """ +{0} requires the uroman library but it was not found in your environment. You can install it with pip: +`pip install uroman`. Please note that you may need to restart your runtime after installation. +""" # docstyle-ignore @@ -1523,6 +1533,7 @@ BACKENDS_MAPPING = OrderedDict( ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)), ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)), ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)), + ("uroman", (is_uroman_available, UROMAN_IMPORT_ERROR)), ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)), ("levenshtein", (is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR)), ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),