From e9688875bf8da160d35593dda60e37182cf9fe98 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 6 Dec 2021 18:49:50 +0100 Subject: [PATCH] Auto processor fix (#14623) * Add AutoProcessor class Init and tests Add doc Fix init Update src/transformers/models/auto/processing_auto.py Co-authored-by: Lysandre Debut Reverts to tokenizer or feature extractor when available Adapt test * Revert "Adapt test" This reverts commit bbdde5fab02465f24b54b227390073082cb32093. * Revert "Reverts to tokenizer or feature extractor when available" This reverts commit 77659ff5d21b6cc0baf6f443017e35e056a525bb. * Don't revert everything Lysandre! Co-authored-by: Sylvain Gugger --- .../models/auto/processing_auto.py | 22 ++----------------- tests/test_processor_auto.py | 12 +--------- 2 files changed, 3 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 349a885593..df68052966 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -28,8 +28,6 @@ from .configuration_auto import ( model_type_to_module_name, replace_list_option_in_docstrings, ) -from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES, AutoFeatureExtractor -from .tokenization_auto import TOKENIZER_MAPPING_NAMES, AutoTokenizer PROCESSOR_MAPPING_NAMES = OrderedDict( @@ -85,9 +83,6 @@ class AutoProcessor: List options - For other types of models, this class will return the appropriate tokenizer (if available) or feature - extractor. - Params: pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): This can be either: @@ -167,24 +162,11 @@ class AutoProcessor: return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs) model_type = config_class_to_model_type(type(config).__name__) - if model_type is not None and model_type in PROCESSOR_MAPPING_NAMES: + if model_type is not None: return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs) - # At this stage there doesn't seem to be a `Processor` class available for this model, so let's try a tokenizer - if model_type in TOKENIZER_MAPPING_NAMES: - return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) - - # At this stage there doesn't seem to be a `Processor` class available for this model, so let's try a tokenizer - if model_type in FEATURE_EXTRACTOR_MAPPING_NAMES: - return AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs) - - all_model_types = set( - PROCESSOR_MAPPING_NAMES.keys() + TOKENIZER_MAPPING_NAMES.keys() + FEATURE_EXTRACTOR_MAPPING_NAMES.keys() - ) - all_model_types = list(all_model_types) - all_model_types.sort() raise ValueError( f"Unrecognized processor in {pretrained_model_name_or_path}. Should have a `processor_type` key in " f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: " - f"{', '.join(all_model_types)}" + f"{', '.join(c for c in PROCESSOR_MAPPING_NAMES.keys())}" ) diff --git a/tests/test_processor_auto.py b/tests/test_processor_auto.py index f587b75da4..3afc6db241 100644 --- a/tests/test_processor_auto.py +++ b/tests/test_processor_auto.py @@ -17,8 +17,7 @@ import os import tempfile import unittest -from transformers import AutoProcessor, BeitFeatureExtractor, BertTokenizerFast, Wav2Vec2Config, Wav2Vec2Processor -from transformers.testing_utils import require_torch +from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") @@ -45,12 +44,3 @@ class AutoFeatureExtractorTest(unittest.TestCase): processor = AutoProcessor.from_pretrained(tmpdirname) self.assertIsInstance(processor, Wav2Vec2Processor) - - def test_auto_processor_reverts_to_tokenizer(self): - processor = AutoProcessor.from_pretrained("bert-base-cased") - self.assertIsInstance(processor, BertTokenizerFast) - - @require_torch - def test_auto_processor_reverts_to_feature_extractor(self): - processor = AutoProcessor.from_pretrained("microsoft/beit-base-patch16-224") - self.assertIsInstance(processor, BeitFeatureExtractor)