Add the ImageClassificationPipeline (#11598)
* Add the ImageClassificationPipeline * Code review Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com> * Have `load_image` at the module level Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -16,9 +16,10 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
|
||||
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
|
||||
)
|
||||
@@ -29,16 +30,10 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
config = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_feature_extractor_from_local_directory(self):
|
||||
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_feature_extractor_from_local_file(self):
|
||||
config = AutoFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG)
|
||||
self.assertIsInstance(config, Wav2Vec2FeatureExtractor)
|
||||
|
||||
def test_pattern_matching_fallback(self):
|
||||
"""
|
||||
In cases where config.json doesn't include a model_type,
|
||||
perform a few safety checks on the config mapping's order.
|
||||
"""
|
||||
# no key string should be included in a later key string (typical failure case)
|
||||
keys = list(FEATURE_EXTRACTOR_MAPPING.keys())
|
||||
for i, key in enumerate(keys):
|
||||
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
|
||||
|
||||
Reference in New Issue
Block a user