[AutoFeatureExtractor] Fix loading of local folders if config.json exists (#13166)
* up * up
This commit is contained in:
committed by
GitHub
parent
439a43b6b4
commit
ecfa7eb260
@@ -20,6 +20,7 @@ from collections import OrderedDict
|
||||
from typing import List, Union
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...file_utils import CONFIG_NAME
|
||||
|
||||
|
||||
CONFIG_MAPPING_NAMES = OrderedDict(
|
||||
@@ -520,6 +521,6 @@ class AutoConfig:
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognized model in {pretrained_model_name_or_path}. "
|
||||
"Should have a `model_type` key in its config.json, or contain one of the following strings "
|
||||
f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
|
||||
f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ from collections import OrderedDict
|
||||
# Build the list of all feature extractors
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||
from ...file_utils import FEATURE_EXTRACTOR_NAME
|
||||
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME
|
||||
from .auto_factory import _LazyAutoMapping
|
||||
from .configuration_auto import (
|
||||
CONFIG_MAPPING_NAMES,
|
||||
@@ -142,7 +142,12 @@ class AutoFeatureExtractor:
|
||||
os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
||||
)
|
||||
|
||||
if not is_feature_extraction_file and not is_directory:
|
||||
has_local_config = (
|
||||
os.path.exists(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)) if is_directory else False
|
||||
)
|
||||
|
||||
# load config, if it can be loaded
|
||||
if not is_feature_extraction_file and (has_local_config or not is_directory):
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
@@ -150,6 +155,7 @@ class AutoFeatureExtractor:
|
||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
model_type = config_class_to_model_type(type(config).__name__)
|
||||
|
||||
if model_type is not None:
|
||||
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
|
||||
elif "feature_extractor_type" in config_dict:
|
||||
@@ -157,7 +163,7 @@ class AutoFeatureExtractor:
|
||||
return feature_extractor_class.from_dict(config_dict, **kwargs)
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
|
||||
f"its {FEATURE_EXTRACTOR_NAME}, or contain one of the following strings "
|
||||
f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}"
|
||||
f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
|
||||
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
|
||||
f"{', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
|
||||
)
|
||||
|
||||
@@ -14,15 +14,17 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor
|
||||
|
||||
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
|
||||
)
|
||||
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
|
||||
|
||||
|
||||
class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
@@ -30,10 +32,27 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_feature_extractor_from_local_directory(self):
|
||||
def test_feature_extractor_from_local_directory_from_key(self):
    """Load a feature extractor from the local fixtures directory.

    The directory is passed straight to `AutoFeatureExtractor.from_pretrained`
    and the result must be a `Wav2Vec2FeatureExtractor`.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
    self.assertIsInstance(feature_extractor, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_feature_extractor_from_local_directory_from_config(self):
    """Load a feature extractor from a local folder using only the model config.

    A default `Wav2Vec2Config` and a feature extractor whose
    `feature_extractor_type` entry has been removed are both saved into a
    temporary directory. `AutoFeatureExtractor.from_pretrained` must still
    return a `Wav2Vec2FeatureExtractor` for that directory.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        model_config = Wav2Vec2Config()

        # remove feature_extractor_type to make sure config.json alone is enough to load feature extractor locally
        config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict()
        config_dict.pop("feature_extractor_type")
        # Unpack the dict into keyword arguments; passing it positionally would
        # bind the whole dict to the constructor's first parameter.
        feature_extractor = Wav2Vec2FeatureExtractor(**config_dict)

        # save in new folder
        model_config.save_pretrained(tmpdirname)
        feature_extractor.save_pretrained(tmpdirname)

        loaded_feature_extractor = AutoFeatureExtractor.from_pretrained(tmpdirname)

    self.assertIsInstance(loaded_feature_extractor, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_feature_extractor_from_local_file(self):
|
||||
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
Reference in New Issue
Block a user