Move pyctcdecode (#14686)
* Move pyctcdecode dep
* Fix doc and last objects
* Quality
* Style
* Ignore this black
This commit is contained in:
@@ -77,7 +77,7 @@ Wav2Vec2ProcessorWithLM
|
|||||||
Wav2Vec2 specific outputs
|
Wav2Vec2 specific outputs
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.models.wav2vec2.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
|
.. autoclass:: transformers.models.wav2vec2_with_lm.processing_wav2vec2_with_lm.Wav2Vec2DecoderWithLMOutput
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
|
.. autoclass:: transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2BaseModelOutput
|
||||||
|
|||||||
@@ -313,7 +313,7 @@ _import_structure = {
|
|||||||
"Wav2Vec2Processor",
|
"Wav2Vec2Processor",
|
||||||
"Wav2Vec2Tokenizer",
|
"Wav2Vec2Tokenizer",
|
||||||
],
|
],
|
||||||
"models.wav2vec2_with_lm": [],
|
"models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"],
|
||||||
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
|
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
|
||||||
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
|
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
|
||||||
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
||||||
@@ -475,15 +475,6 @@ else:
|
|||||||
name for name in dir(dummy_speech_objects) if not name.startswith("_")
|
name for name in dir(dummy_speech_objects) if not name.startswith("_")
|
||||||
]
|
]
|
||||||
|
|
||||||
if is_pyctcdecode_available():
|
|
||||||
_import_structure["models.wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
|
|
||||||
else:
|
|
||||||
from .utils import dummy_pyctcdecode_objects
|
|
||||||
|
|
||||||
_import_structure["utils.dummy_pyctcdecode_objects"] = [
|
|
||||||
name for name in dir(dummy_pyctcdecode_objects) if not name.startswith("_")
|
|
||||||
]
|
|
||||||
|
|
||||||
if is_sentencepiece_available() and is_speech_available():
|
if is_sentencepiece_available() and is_speech_available():
|
||||||
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")
|
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")
|
||||||
else:
|
else:
|
||||||
@@ -2329,6 +2320,7 @@ if TYPE_CHECKING:
|
|||||||
Wav2Vec2Processor,
|
Wav2Vec2Processor,
|
||||||
Wav2Vec2Tokenizer,
|
Wav2Vec2Tokenizer,
|
||||||
)
|
)
|
||||||
|
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||||
from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
|
from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer
|
||||||
from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
|
from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
|
||||||
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
||||||
@@ -2472,11 +2464,6 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
from .utils.dummy_speech_objects import *
|
from .utils.dummy_speech_objects import *
|
||||||
|
|
||||||
if is_pyctcdecode_available():
|
|
||||||
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
|
||||||
else:
|
|
||||||
from .utils.dummy_pyctcdecode_objects import *
|
|
||||||
|
|
||||||
if is_speech_available() and is_sentencepiece_available():
|
if is_speech_available() and is_sentencepiece_available():
|
||||||
from .models.speech_to_text import Speech2TextProcessor
|
from .models.speech_to_text import Speech2TextProcessor
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -17,19 +17,18 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ...file_utils import _LazyModule, is_pyctcdecode_available
|
from ...file_utils import _LazyModule
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"processing_wav2vec2_with_lm": []}
|
# fmt: off
|
||||||
|
_import_structure = {
|
||||||
|
"processing_wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"]
|
||||||
if is_pyctcdecode_available():
|
}
|
||||||
_import_structure["processing_wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
if is_pyctcdecode_available():
|
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||||
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
|
||||||
else:
|
else:
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
@@ -19,19 +19,10 @@ import os
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from typing import Iterable, List, Optional, Union
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from pyctcdecode import BeamSearchDecoderCTC
|
|
||||||
from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN
|
|
||||||
from pyctcdecode.constants import (
|
|
||||||
DEFAULT_BEAM_WIDTH,
|
|
||||||
DEFAULT_HOTWORD_WEIGHT,
|
|
||||||
DEFAULT_MIN_TOKEN_LOGP,
|
|
||||||
DEFAULT_PRUNE_LOGP,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||||
from ...file_utils import ModelOutput, requires_backends
|
from ...file_utils import ModelOutput, requires_backends
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
@@ -39,6 +30,10 @@ from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
|
|||||||
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
|
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pyctcdecode import BeamSearchDecoderCTC
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Wav2Vec2DecoderWithLMOutput(ModelOutput):
|
class Wav2Vec2DecoderWithLMOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -70,8 +65,10 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
self,
|
self,
|
||||||
feature_extractor: FeatureExtractionMixin,
|
feature_extractor: FeatureExtractionMixin,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
decoder: BeamSearchDecoderCTC,
|
decoder: "BeamSearchDecoderCTC",
|
||||||
):
|
):
|
||||||
|
from pyctcdecode import BeamSearchDecoderCTC
|
||||||
|
|
||||||
if not isinstance(feature_extractor, Wav2Vec2FeatureExtractor):
|
if not isinstance(feature_extractor, Wav2Vec2FeatureExtractor):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}"
|
f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}"
|
||||||
@@ -153,6 +150,8 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
:class:`~transformers.PreTrainedTokenizer`
|
:class:`~transformers.PreTrainedTokenizer`
|
||||||
"""
|
"""
|
||||||
requires_backends(cls, "pyctcdecode")
|
requires_backends(cls, "pyctcdecode")
|
||||||
|
from pyctcdecode import BeamSearchDecoderCTC
|
||||||
|
|
||||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
@@ -183,7 +182,7 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)
|
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _set_language_model_attribute(decoder: BeamSearchDecoderCTC, attribute: str, value: float):
|
def _set_language_model_attribute(decoder: "BeamSearchDecoderCTC", attribute: str, value: float):
|
||||||
setattr(decoder.model_container[decoder._model_key], attribute, value)
|
setattr(decoder.model_container[decoder._model_key], attribute, value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -192,6 +191,8 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_missing_alphabet_tokens(decoder, tokenizer):
|
def get_missing_alphabet_tokens(decoder, tokenizer):
|
||||||
|
from pyctcdecode.alphabet import BLANK_TOKEN_PTN, UNK_TOKEN, UNK_TOKEN_PTN
|
||||||
|
|
||||||
# we need to make sure that all of the tokenizer's except the special tokens
|
# we need to make sure that all of the tokenizer's except the special tokens
|
||||||
# are present in the decoder's alphabet. Retrieve missing alphabet token
|
# are present in the decoder's alphabet. Retrieve missing alphabet token
|
||||||
# from decoder
|
# from decoder
|
||||||
@@ -270,6 +271,12 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
:class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
|
:class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
from pyctcdecode.constants import (
|
||||||
|
DEFAULT_BEAM_WIDTH,
|
||||||
|
DEFAULT_HOTWORD_WEIGHT,
|
||||||
|
DEFAULT_MIN_TOKEN_LOGP,
|
||||||
|
DEFAULT_PRUNE_LOGP,
|
||||||
|
)
|
||||||
|
|
||||||
# set defaults
|
# set defaults
|
||||||
beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
|
beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
|
||||||
@@ -330,6 +337,12 @@ class Wav2Vec2ProcessorWithLM:
|
|||||||
:class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
|
:class:`~transformers.models.wav2vec2.Wav2Vec2DecoderWithLMOutput` or :obj:`tuple`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
from pyctcdecode.constants import (
|
||||||
|
DEFAULT_BEAM_WIDTH,
|
||||||
|
DEFAULT_HOTWORD_WEIGHT,
|
||||||
|
DEFAULT_MIN_TOKEN_LOGP,
|
||||||
|
DEFAULT_PRUNE_LOGP,
|
||||||
|
)
|
||||||
|
|
||||||
# set defaults
|
# set defaults
|
||||||
beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
|
beam_width = beam_width if beam_width is not None else DEFAULT_BEAM_WIDTH
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
|
||||||
from ..file_utils import requires_backends
|
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2ProcessorWithLM:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_backends(self, ["pyctcdecode"])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
|
||||||
requires_backends(cls, ["pyctcdecode"])
|
|
||||||
Reference in New Issue
Block a user