Move pyctcdecode (#14686)
* Move pyctcdecode dep
* Fix doc and last objects
* Quality
* Style
* Ignore this black
This commit is contained in:
@@ -77,7 +77,7 @@ Wav2Vec2ProcessorWithLM
|
||||
Wav2Vec2 specific outputs
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.models.wav2vec2.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
|
||||
.. autoclass:: transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
|
||||
|
||||
@@ -313,7 +313,7 @@ _import_structure = {
|
||||
"Wav2Vec2Processor",
|
||||
"Wav2Vec2Tokenizer",
|
||||
],
|
||||
"models.wav2vec2_with_lm": [],
|
||||
"models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"],
|
||||
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
|
||||
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
|
||||
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
||||
@@ -475,15 +475,6 @@ else:
|
||||
name for name in dir(dummy_speech_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
_import_structure["models.wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
|
||||
else:
|
||||
from .utils import dummy_pyctcdecode_objects
|
||||
|
||||
_import_structure["utils.dummy_pyctcdecode_objects"] = [
|
||||
name for name in dir(dummy_pyctcdecode_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
if is_sentencepiece_available() and is_speech_available():
|
||||
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")
|
||||
else:
|
||||
@@ -2329,6 +2320,7 @@ if TYPE_CHECKING:
|
||||
Wav2Vec2Processor,
|
||||
Wav2Vec2Tokenizer,
|
||||
)
|
||||
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||
from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
|
||||
from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
|
||||
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
||||
@@ -2472,11 +2464,6 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
from .utils.dummy_speech_objects import *
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||
else:
|
||||
from .utils.dummy_pyctcdecode_objects import *
|
||||
|
||||
if is_speech_available() and is_sentencepiece_available():
|
||||
from .models.speech_to_text import Speech2TextProcessor
|
||||
else:
|
||||
|
||||
@@ -17,19 +17,18 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _LazyModule, is_pyctcdecode_available
|
||||
from ...file_utils import _LazyModule
|
||||
|
||||
|
||||
_import_structure = {"processing_wav2vec2_with_lm": []}
|
||||
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
_import_structure["processing_wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
|
||||
# fmt: off
|
||||
_import_structure = {
|
||||
"processing_wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"]
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if is_pyctcdecode_available():
|
||||
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -19,19 +19,10 @@ import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Pool
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN
|
||||
from pyctcdecode.constants import (
|
||||
DEFAULT_BEAM_WIDTH,
|
||||
DEFAULT_HOTWORD_WEIGHT,
|
||||
DEFAULT_MIN_TOKEN_LOGP,
|
||||
DEFAULT_PRUNE_LOGP,
|
||||
)
|
||||
|
||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||
from ...file_utils import ModelOutput, requires_backends
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
@@ -39,6 +30,10 @@ from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
|
||||
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
|
||||
|
||||
@dataclass
|
||||
class Wav2Vec2DecoderWithLMOutput(ModelOutput):
|
||||
"""
|
||||
@@ -70,8 +65,10 @@ class Wav2Vec2ProcessorWithLM:
|
||||
self,
|
||||
feature_extractor: FeatureExtractionMixin,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
decoder: BeamSearchDecoderCTC,
|
||||
decoder: "BeamSearchDecoderCTC",
|
||||
):
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
|
||||
if not isinstance(feature_extractor, Wav2Vec2FeatureExtractor):
|
||||
raise ValueError(
|
||||
f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}"
|
||||
@@ -153,6 +150,8 @@ class Wav2Vec2ProcessorWithLM:
|
||||
:class:`~transformers.PreTrainedTokenizer`
|
||||
"""
|
||||
requires_backends(cls, "pyctcdecode")
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
@@ -183,7 +182,7 @@ class Wav2Vec2ProcessorWithLM:
|
||||
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)
|
||||
|
||||
@staticmethod
|
||||
def _set_language_model_attribute(decoder: BeamSearchDecoderCTC, attribute: str, value: float):
|
||||
def _set_language_model_attribute(decoder: "BeamSearchDecoderCTC", attribute: str, value: float):
|
||||
setattr(decoder.model_container[decoder._model_key], attribute, value)
|
||||
|
||||
@property
|
||||
@@ -192,6 +191,8 @@ class Wav2Vec2ProcessorWithLM:
|
||||
|
||||
@staticmethod
|
||||
def get_missing_alphabet_tokens(decoder, tokenizer):
|
||||
from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN
|
||||
|
||||
# we need to make sure that all of the tokenizer's tokens except the special tokens
|
||||
# are present in the decoder's alphabet. Retrieve missing alphabet token
|
||||
# from decoder
|
||||
@@ -270,6 +271,12 @@ class Wav2Vec2ProcessorWithLM:
|
||||
:class:`~transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
|
||||
|
||||
"""
|
||||
from pyctcdecode.constants import (
|
||||
DEFAULT_BEAM_WIDTH,
|
||||
DEFAULT_HOTWORD_WEIGHT,
|
||||
DEFAULT_MIN_TOKEN_LOGP,
|
||||
DEFAULT_PRUNE_LOGP,
|
||||
)
|
||||
|
||||
# set defaults
|
||||
beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
|
||||
@@ -330,6 +337,12 @@ class Wav2Vec2ProcessorWithLM:
|
||||
:class:`~transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
|
||||
|
||||
"""
|
||||
from pyctcdecode.constants import (
|
||||
DEFAULT_BEAM_WIDTH,
|
||||
DEFAULT_HOTWORD_WEIGHT,
|
||||
DEFAULT_MIN_TOKEN_LOGP,
|
||||
DEFAULT_PRUNE_LOGP,
|
||||
)
|
||||
|
||||
# set defaults
|
||||
beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
from ..file_utils import requires_backends
|
||||
|
||||
|
||||
class Wav2Vec2ProcessorWithLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["pyctcdecode"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["pyctcdecode"])
|
||||
Reference in New Issue
Block a user