Auto processor fix (#14623)
* Add AutoProcessor class
  Init and tests
  Add doc
  Fix init
  Update src/transformers/models/auto/processing_auto.py
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
  Reverts to tokenizer or feature extractor when available
  Adapt test
* Revert "Adapt test"
  This reverts commit bbdde5fab02465f24b54b227390073082cb32093.
* Revert "Reverts to tokenizer or feature extractor when available"
  This reverts commit 77659ff5d21b6cc0baf6f443017e35e056a525bb.
* Don't revert everything Lysandre!
Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
This commit is contained in:
@@ -28,8 +28,6 @@ from .configuration_auto import (
|
|||||||
model_type_to_module_name,
|
model_type_to_module_name,
|
||||||
replace_list_option_in_docstrings,
|
replace_list_option_in_docstrings,
|
||||||
)
|
)
|
||||||
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES, AutoFeatureExtractor
|
|
||||||
from .tokenization_auto import TOKENIZER_MAPPING_NAMES, AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
PROCESSOR_MAPPING_NAMES = OrderedDict(
|
PROCESSOR_MAPPING_NAMES = OrderedDict(
|
||||||
@@ -85,9 +83,6 @@ class AutoProcessor:
|
|||||||
|
|
||||||
List options
|
List options
|
||||||
|
|
||||||
For other types of models, this class will return the appropriate tokenizer (if available) or feature
|
|
||||||
extractor.
|
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||||
This can be either:
|
This can be either:
|
||||||
@@ -167,24 +162,11 @@ class AutoProcessor:
|
|||||||
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
model_type = config_class_to_model_type(type(config).__name__)
|
model_type = config_class_to_model_type(type(config).__name__)
|
||||||
if model_type is not None and model_type in PROCESSOR_MAPPING_NAMES:
|
if model_type is not None:
|
||||||
return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
# At this stage there doesn't seem to be a `Processor` class available for this model, so let's try a tokenizer
|
|
||||||
if model_type in TOKENIZER_MAPPING_NAMES:
|
|
||||||
return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
|
||||||
|
|
||||||
# At this stage there doesn't seem to be a `Processor` class or a tokenizer available for this model, so let's try a feature extractor
|
|
||||||
if model_type in FEATURE_EXTRACTOR_MAPPING_NAMES:
|
|
||||||
return AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
|
||||||
|
|
||||||
all_model_types = set(
|
|
||||||
PROCESSOR_MAPPING_NAMES.keys() + TOKENIZER_MAPPING_NAMES.keys() + FEATURE_EXTRACTOR_MAPPING_NAMES.keys()
|
|
||||||
)
|
|
||||||
all_model_types = list(all_model_types)
|
|
||||||
all_model_types.sort()
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized processor in {pretrained_model_name_or_path}. Should have a `processor_type` key in "
|
f"Unrecognized processor in {pretrained_model_name_or_path}. Should have a `processor_type` key in "
|
||||||
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
|
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
|
||||||
f"{', '.join(all_model_types)}"
|
f"{', '.join(c for c in PROCESSOR_MAPPING_NAMES.keys())}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ import os
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoProcessor, BeitFeatureExtractor, BertTokenizerFast, Wav2Vec2Config, Wav2Vec2Processor
|
from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
|
||||||
from transformers.testing_utils import require_torch
|
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||||
@@ -45,12 +44,3 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
processor = AutoProcessor.from_pretrained(tmpdirname)
|
processor = AutoProcessor.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
self.assertIsInstance(processor, Wav2Vec2Processor)
|
self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||||
|
|
||||||
def test_auto_processor_reverts_to_tokenizer(self):
|
|
||||||
processor = AutoProcessor.from_pretrained("bert-base-cased")
|
|
||||||
self.assertIsInstance(processor, BertTokenizerFast)
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_auto_processor_reverts_to_feature_extractor(self):
|
|
||||||
processor = AutoProcessor.from_pretrained("microsoft/beit-base-patch16-224")
|
|
||||||
self.assertIsInstance(processor, BeitFeatureExtractor)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user