From 13186d7152c3795878265d4699423847d6299866 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 8 Dec 2021 15:41:58 -0500 Subject: [PATCH] Move pyctcdecode (#14686) * Move pyctcdecode dep * Fix doc and last objects * Quality * Style * Ignore this black --- docs/source/model_doc/wav2vec2.rst | 2 +- src/transformers/__init__.py | 17 +-------- .../models/wav2vec2_with_lm/__init__.py | 15 ++++---- .../processing_wav2vec2_with_lm.py | 37 +++++++++++++------ .../utils/dummy_pyctcdecode_objects.py | 11 ------ 5 files changed, 35 insertions(+), 47 deletions(-) delete mode 100644 src/transformers/utils/dummy_pyctcdecode_objects.py diff --git a/docs/source/model_doc/wav2vec2.rst b/docs/source/model_doc/wav2vec2.rst index 8c3d5481ac..3ac721b2f9 100644 --- a/docs/source/model_doc/wav2vec2.rst +++ b/docs/source/model_doc/wav2vec2.rst @@ -77,7 +77,7 @@ Wav2Vec2ProcessorWithLM Wav2Vec2 specific outputs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.models.wav2vec2.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput +.. autoclass:: transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput :members: .. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b974552460..e0198b1f14 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -313,7 +313,7 @@ _import_structure = { "Wav2Vec2Processor", "Wav2Vec2Tokenizer", ], - "models.wav2vec2_with_lm": [], + "models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"], "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"], "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"], "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"], @@ -475,15 +475,6 @@ else: name for name in dir(dummy_speech_objects) if not name.startswith("_") ] -if is_pyctcdecode_available(): - _import_structure["models.wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM") -else: - from .utils import dummy_pyctcdecode_objects - - _import_structure["utils.dummy_pyctcdecode_objects"] = [ - name for name in dir(dummy_pyctcdecode_objects) if not name.startswith("_") - ] - if is_sentencepiece_available() and is_speech_available(): _import_structure["models.speech_to_text"].append("Speech2TextProcessor") else: @@ -2329,6 +2320,7 @@ if TYPE_CHECKING: Wav2Vec2Processor, Wav2Vec2Tokenizer, ) + from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig @@ -2472,11 +2464,6 @@ if TYPE_CHECKING: else: from .utils.dummy_speech_objects import * - if is_pyctcdecode_available(): - from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM - else: - from .utils.dummy_pyctcdecode_objects import * - if is_speech_available() and is_sentencepiece_available(): from .models.speech_to_text import Speech2TextProcessor else: diff --git a/src/transformers/models/wav2vec2_with_lm/__init__.py b/src/transformers/models/wav2vec2_with_lm/__init__.py index 4b03c83252..cca731f0f7 100644 --- a/src/transformers/models/wav2vec2_with_lm/__init__.py +++ b/src/transformers/models/wav2vec2_with_lm/__init__.py @@ -17,19 +17,18 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...file_utils import _LazyModule, is_pyctcdecode_available +from ...file_utils import _LazyModule -_import_structure = {"processing_wav2vec2_with_lm": []} - - -if is_pyctcdecode_available(): - _import_structure["processing_wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM") +# fmt: off +_import_structure = { + "processing_wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"] +} +# fmt: on if TYPE_CHECKING: - if is_pyctcdecode_available(): - from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM + from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM else: import sys diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py index 750a49d473..291898274d 100644 --- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py @@ -19,19 +19,10 @@ import os from contextlib import contextmanager from dataclasses import dataclass from multiprocessing import Pool -from typing import Iterable, List, Optional, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Union import numpy as np -from pyctcdecode import BeamSearchDecoderCTC -from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN -from pyctcdecode.constants import ( - DEFAULT_BEAM_WIDTH, - DEFAULT_HOTWORD_WEIGHT, - DEFAULT_MIN_TOKEN_LOGP, - DEFAULT_PRUNE_LOGP, -) - from ...feature_extraction_utils import FeatureExtractionMixin from ...file_utils import ModelOutput, requires_backends from ...tokenization_utils import PreTrainedTokenizer @@ -39,6 +30,10 @@ from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer +if TYPE_CHECKING: + from pyctcdecode import BeamSearchDecoderCTC + + @dataclass class Wav2Vec2DecoderWithLMOutput(ModelOutput): """ @@ -70,8 +65,10 @@ class Wav2Vec2ProcessorWithLM: self, feature_extractor: FeatureExtractionMixin, tokenizer: PreTrainedTokenizer, - decoder: BeamSearchDecoderCTC, + decoder: "BeamSearchDecoderCTC", ): + from pyctcdecode import BeamSearchDecoderCTC + if not isinstance(feature_extractor, Wav2Vec2FeatureExtractor): raise ValueError( f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}" @@ -153,6 +150,8 @@ class Wav2Vec2ProcessorWithLM: :class:`~transformers.PreTrainedTokenizer` """ requires_backends(cls, "pyctcdecode") + from pyctcdecode import BeamSearchDecoderCTC + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs) tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) @@ -183,7 +182,7 @@ class Wav2Vec2ProcessorWithLM: return cls(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder) @staticmethod - def _set_language_model_attribute(decoder: BeamSearchDecoderCTC, attribute: str, value: float): + def _set_language_model_attribute(decoder: "BeamSearchDecoderCTC", attribute: str, value: float): setattr(decoder.model_container[decoder._model_key], attribute, value) @property @@ -192,6 +191,8 @@ class Wav2Vec2ProcessorWithLM: @staticmethod def get_missing_alphabet_tokens(decoder, tokenizer): + from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN + # we need to make sure that all of the tokenizer's except the special tokens # are present in the decoder's alphabet. Retrieve missing alphabet token # from decoder @@ -270,6 +271,12 @@ class Wav2Vec2ProcessorWithLM: :class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`. """ + from pyctcdecode.constants import ( + DEFAULT_BEAM_WIDTH, + DEFAULT_HOTWORD_WEIGHT, + DEFAULT_MIN_TOKEN_LOGP, + DEFAULT_PRUNE_LOGP, + ) # set defaults beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH @@ -330,6 +337,12 @@ class Wav2Vec2ProcessorWithLM: :class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`. """ + from pyctcdecode.constants import ( + DEFAULT_BEAM_WIDTH, + DEFAULT_HOTWORD_WEIGHT, + DEFAULT_MIN_TOKEN_LOGP, + DEFAULT_PRUNE_LOGP, + ) # set defaults beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH diff --git a/src/transformers/utils/dummy_pyctcdecode_objects.py b/src/transformers/utils/dummy_pyctcdecode_objects.py deleted file mode 100644 index fee38b3dac..0000000000 --- a/src/transformers/utils/dummy_pyctcdecode_objects.py +++ /dev/null @@ -1,11 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..file_utils import requires_backends - - -class Wav2Vec2ProcessorWithLM: - def __init__(self, *args, **kwargs): - requires_backends(self, ["pyctcdecode"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["pyctcdecode"])