From cdc51ffd27f8f5a3151da161ae2b5dbb410d2803 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 16 Feb 2022 09:13:33 -0500 Subject: [PATCH] Add register method to AutoProcessor (#15669) * Add push_to_hub method to processors * Fix test * The other one too! * Add register method to AutoProcessor * Update src/transformers/models/auto/processing_auto.py Co-authored-by: Lysandre Debut Co-authored-by: Lysandre Debut --- .../models/auto/processing_auto.py | 17 +++++- tests/test_processor_auto.py | 53 ++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 7b1365a3e3..68b846da96 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -60,7 +60,10 @@ def processor_class_from_name(class_name: str): module = importlib.import_module(f".{module_name}", "transformers.models") return getattr(module, class_name) - break + + for processor in PROCESSOR_MAPPING._extra_content.values(): + if getattr(processor, "__name__", None) == class_name: + return processor return None @@ -231,3 +234,15 @@ class AutoProcessor: f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: " f"{', '.join(c for c in PROCESSOR_MAPPING_NAMES.keys())}" ) + + @staticmethod + def register(config_class, processor_class): + """ + Register a new processor for this class. + + Args: + config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + processor_class ([`ProcessorMixin`]): The processor to register.
+ """ + PROCESSOR_MAPPING.register(config_class, processor_class) diff --git a/tests/test_processor_auto.py b/tests/test_processor_auto.py index 3f56c95fde..d4a543ee5c 100644 --- a/tests/test_processor_auto.py +++ b/tests/test_processor_auto.py @@ -23,7 +23,19 @@ from shutil import copyfile from huggingface_hub import Repository, delete_repo, login from requests.exceptions import HTTPError -from transformers import AutoProcessor, AutoTokenizer, Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2Processor +from transformers import ( + CONFIG_MAPPING, + FEATURE_EXTRACTOR_MAPPING, + PROCESSOR_MAPPING, + TOKENIZER_MAPPING, + AutoConfig, + AutoFeatureExtractor, + AutoProcessor, + AutoTokenizer, + Wav2Vec2Config, + Wav2Vec2FeatureExtractor, + Wav2Vec2Processor, +) from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_tokenizers_available from transformers.testing_utils import PASS, USER, is_staging_test from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE @@ -31,6 +43,7 @@ from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE sys.path.append(str(Path(__file__).parent.parent / "utils")) +from test_module.custom_configuration import CustomConfig # noqa E402 from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402 from test_module.custom_processing import CustomProcessor # noqa E402 from test_module.custom_tokenization import CustomTokenizer # noqa E402 @@ -45,6 +58,8 @@ SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__fil class AutoFeatureExtractorTest(unittest.TestCase): + vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"] + def test_processor_from_model_shortcut(self): processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h") self.assertIsInstance(processor, Wav2Vec2Processor) @@ -154,6 +169,42 @@ class AutoFeatureExtractorTest(unittest.TestCase): else: self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") + def 
test_new_processor_registration(self): + try: + AutoConfig.register("custom", CustomConfig) + AutoFeatureExtractor.register(CustomConfig, CustomFeatureExtractor) + AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer) + AutoProcessor.register(CustomConfig, CustomProcessor) + # Trying to register something existing in the Transformers library will raise an error + with self.assertRaises(ValueError): + AutoProcessor.register(Wav2Vec2Config, Wav2Vec2Processor) + + # Now that the config is registered, it can be used as any other config with the auto-API + feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR) + + with tempfile.TemporaryDirectory() as tmp_dir: + vocab_file = os.path.join(tmp_dir, "vocab.txt") + with open(vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) + tokenizer = CustomTokenizer(vocab_file) + + processor = CustomProcessor(feature_extractor, tokenizer) + + with tempfile.TemporaryDirectory() as tmp_dir: + processor.save_pretrained(tmp_dir) + new_processor = AutoProcessor.from_pretrained(tmp_dir) + self.assertIsInstance(new_processor, CustomProcessor) + + finally: + if "custom" in CONFIG_MAPPING._extra_content: + del CONFIG_MAPPING._extra_content["custom"] + if CustomConfig in FEATURE_EXTRACTOR_MAPPING._extra_content: + del FEATURE_EXTRACTOR_MAPPING._extra_content[CustomConfig] + if CustomConfig in TOKENIZER_MAPPING._extra_content: + del TOKENIZER_MAPPING._extra_content[CustomConfig] + if CustomConfig in PROCESSOR_MAPPING._extra_content: + del PROCESSOR_MAPPING._extra_content[CustomConfig] + @is_staging_test class ProcessorPushToHubTester(unittest.TestCase):