From ecfa7eb2606338641715ac2becbbd7de4dd8031b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 18 Aug 2021 16:18:13 +0200 Subject: [PATCH] [AutoFeatureExtractor] Fix loading of local folders if config.json exists (#13166) * up * up --- .../models/auto/configuration_auto.py | 3 ++- .../models/auto/feature_extraction_auto.py | 16 +++++++++---- tests/test_feature_extraction_auto.py | 23 +++++++++++++++++-- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index df7400e917..fa42da76f5 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -20,6 +20,7 @@ from collections import OrderedDict from typing import List, Union from ...configuration_utils import PretrainedConfig +from ...file_utils import CONFIG_NAME CONFIG_MAPPING_NAMES = OrderedDict( @@ -520,6 +521,6 @@ class AutoConfig: raise ValueError( f"Unrecognized model in {pretrained_model_name_or_path}. 
" - "Should have a `model_type` key in its config.json, or contain one of the following strings " + f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings " f"in its name: {', '.join(CONFIG_MAPPING.keys())}" ) diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 39ba15f5ac..54c03a3dcb 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -20,7 +20,7 @@ from collections import OrderedDict # Build the list of all feature extractors from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import FeatureExtractionMixin -from ...file_utils import FEATURE_EXTRACTOR_NAME +from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME from .auto_factory import _LazyAutoMapping from .configuration_auto import ( CONFIG_MAPPING_NAMES, @@ -142,7 +142,12 @@ class AutoFeatureExtractor: os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) ) - if not is_feature_extraction_file and not is_directory: + has_local_config = ( + os.path.exists(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)) if is_directory else False + ) + + # load config, if it can be loaded + if not is_feature_extraction_file and (has_local_config or not is_directory): if not isinstance(config, PretrainedConfig): config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) @@ -150,6 +155,7 @@ class AutoFeatureExtractor: config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) model_type = config_class_to_model_type(type(config).__name__) + if model_type is not None: return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs) elif "feature_extractor_type" in config_dict: @@ -157,7 +163,7 @@ class AutoFeatureExtractor: return feature_extractor_class.from_dict(config_dict, **kwargs) raise 
ValueError( - f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in " - f"its {FEATURE_EXTRACTOR_NAME}, or contain one of the following strings " - f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}" + f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in " + f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: " + f"{', '.join(FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}" ) diff --git a/tests/test_feature_extraction_auto.py b/tests/test_feature_extraction_auto.py index 7502e84224..5b219e0d51 100644 --- a/tests/test_feature_extraction_auto.py +++ b/tests/test_feature_extraction_auto.py @@ -14,15 +14,17 @@ # limitations under the License. import os +import tempfile import unittest -from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor +from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures") SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join( os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json" ) +SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json") class AutoFeatureExtractorTest(unittest.TestCase): @@ -30,10 +32,27 @@ class AutoFeatureExtractorTest(unittest.TestCase): config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") self.assertIsInstance(config, Wav2Vec2FeatureExtractor) - def test_feature_extractor_from_local_directory(self): + def test_feature_extractor_from_local_directory_from_key(self): config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) self.assertIsInstance(config, Wav2Vec2FeatureExtractor) + def test_feature_extractor_from_local_directory_from_config(self): + with tempfile.TemporaryDirectory() 
as tmpdirname: + model_config = Wav2Vec2Config() + + # remove feature_extractor_type to make sure config.json alone is enough to load feature extractor locally + config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict() + config_dict.pop("feature_extractor_type") + config = Wav2Vec2FeatureExtractor(**config_dict) + + # save in new folder + model_config.save_pretrained(tmpdirname) + config.save_pretrained(tmpdirname) + + config = AutoFeatureExtractor.from_pretrained(tmpdirname) + + self.assertIsInstance(config, Wav2Vec2FeatureExtractor) + def test_feature_extractor_from_local_file(self): config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG) self.assertIsInstance(config, Wav2Vec2FeatureExtractor)