[AutoFeatureExtractor] Fix loading of local folders if config.json exists (#13166)
* up * up
This commit is contained in:
committed by
GitHub
parent
439a43b6b4
commit
ecfa7eb260
@@ -20,6 +20,7 @@ from collections import OrderedDict
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...file_utils import CONFIG_NAME
|
||||||
|
|
||||||
|
|
||||||
CONFIG_MAPPING_NAMES = OrderedDict(
|
CONFIG_MAPPING_NAMES = OrderedDict(
|
||||||
@@ -520,6 +521,6 @@ class AutoConfig:
|
|||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized model in {pretrained_model_name_or_path}. "
|
f"Unrecognized model in {pretrained_model_name_or_path}. "
|
||||||
"Should have a `model_type` key in its config.json, or contain one of the following strings "
|
f"Should have a `model_type` key in its {CONFIG_NAME}, or contain one of the following strings "
|
||||||
f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
|
f"in its name: {', '.join(CONFIG_MAPPING.keys())}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from collections import OrderedDict
|
|||||||
# Build the list of all feature extractors
|
# Build the list of all feature extractors
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...feature_extraction_utils import FeatureExtractionMixin
|
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||||
from ...file_utils import FEATURE_EXTRACTOR_NAME
|
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME
|
||||||
from .auto_factory import _LazyAutoMapping
|
from .auto_factory import _LazyAutoMapping
|
||||||
from .configuration_auto import (
|
from .configuration_auto import (
|
||||||
CONFIG_MAPPING_NAMES,
|
CONFIG_MAPPING_NAMES,
|
||||||
@@ -142,7 +142,12 @@ class AutoFeatureExtractor:
|
|||||||
os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_feature_extraction_file and not is_directory:
|
has_local_config = (
|
||||||
|
os.path.exists(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)) if is_directory else False
|
||||||
|
)
|
||||||
|
|
||||||
|
# load config, if it can be loaded
|
||||||
|
if not is_feature_extraction_file and (has_local_config or not is_directory):
|
||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
@@ -150,6 +155,7 @@ class AutoFeatureExtractor:
|
|||||||
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
model_type = config_class_to_model_type(type(config).__name__)
|
model_type = config_class_to_model_type(type(config).__name__)
|
||||||
|
|
||||||
if model_type is not None:
|
if model_type is not None:
|
||||||
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
|
return FEATURE_EXTRACTOR_MAPPING[type(config)].from_dict(config_dict, **kwargs)
|
||||||
elif "feature_extractor_type" in config_dict:
|
elif "feature_extractor_type" in config_dict:
|
||||||
@@ -157,7 +163,7 @@ class AutoFeatureExtractor:
|
|||||||
return feature_extractor_class.from_dict(config_dict, **kwargs)
|
return feature_extractor_class.from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized model in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
|
f"Unrecognized feature extractor in {pretrained_model_name_or_path}. Should have a `feature_extractor_type` key in "
|
||||||
f"its {FEATURE_EXTRACTOR_NAME}, or contain one of the following strings "
|
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
|
||||||
f"in its name: {', '.join(FEATURE_EXTRACTOR_MAPPING.keys())}"
|
f"{', '.join(c for c in FEATURE_EXTRACTOR_MAPPING_NAMES.keys())}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -14,15 +14,17 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
from transformers import AutoFeatureExtractor, Wav2Vec2Config, Wav2Vec2FeatureExtractor
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||||
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
|
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
|
||||||
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
|
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
|
||||||
)
|
)
|
||||||
|
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
|
||||||
|
|
||||||
|
|
||||||
class AutoFeatureExtractorTest(unittest.TestCase):
|
class AutoFeatureExtractorTest(unittest.TestCase):
|
||||||
@@ -30,10 +32,27 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||||
|
|
||||||
def test_feature_extractor_from_local_directory(self):
|
def test_feature_extractor_from_local_directory_from_key(self):
|
||||||
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||||
|
|
||||||
|
def test_feature_extractor_from_local_directory_from_config(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model_config = Wav2Vec2Config()
|
||||||
|
|
||||||
|
# remove feature_extractor_type to make sure config.json alone is enough to load the feature extractor locally
|
||||||
|
config_dict = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR).to_dict()
|
||||||
|
config_dict.pop("feature_extractor_type")
|
||||||
|
config = Wav2Vec2FeatureExtractor(config_dict)
|
||||||
|
|
||||||
|
# save in new folder
|
||||||
|
model_config.save_pretrained(tmpdirname)
|
||||||
|
config.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
config = AutoFeatureExtractor.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||||
|
|
||||||
def test_feature_extractor_from_local_file(self):
|
def test_feature_extractor_from_local_file(self):
|
||||||
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
|
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
|
||||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||||
|
|||||||
Reference in New Issue
Block a user