Add register method to AutoProcessor (#15669)
* Add push_to_hub method to processors
* Fix test
* The other one too!
* Add register method to AutoProcessor
* Update src/transformers/models/auto/processing_auto.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -60,7 +60,10 @@ def processor_class_from_name(class_name: str):
|
|||||||
|
|
||||||
module = importlib.import_module(f".{module_name}", "transformers.models")
|
module = importlib.import_module(f".{module_name}", "transformers.models")
|
||||||
return getattr(module, class_name)
|
return getattr(module, class_name)
|
||||||
break
|
|
||||||
|
for processor in PROCESSOR_MAPPING._extra_content.values():
|
||||||
|
if getattr(processor, "__name__", None) == class_name:
|
||||||
|
return processor
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -231,3 +234,15 @@ class AutoProcessor:
|
|||||||
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
|
f"its {FEATURE_EXTRACTOR_NAME}, or one of the following `model_type` keys in its {CONFIG_NAME}: "
|
||||||
f"{', '.join(c for c in PROCESSOR_MAPPING_NAMES.keys())}"
|
f"{', '.join(c for c in PROCESSOR_MAPPING_NAMES.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def register(config_class, processor_class):
|
||||||
|
"""
|
||||||
|
Register a new processor for this class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_class ([`PretrainedConfig`]):
|
||||||
|
The configuration corresponding to the model to register.
|
||||||
|
processor_class ([`ProcessorMixin`]): The processor to register.
|
||||||
|
"""
|
||||||
|
PROCESSOR_MAPPING.register(config_class, processor_class)
|
||||||
|
|||||||
@@ -23,7 +23,19 @@ from shutil import copyfile
|
|||||||
|
|
||||||
from huggingface_hub import Repository, delete_repo, login
|
from huggingface_hub import Repository, delete_repo, login
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import AutoProcessor, AutoTokenizer, Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
from transformers import (
|
||||||
|
CONFIG_MAPPING,
|
||||||
|
FEATURE_EXTRACTOR_MAPPING,
|
||||||
|
PROCESSOR_MAPPING,
|
||||||
|
TOKENIZER_MAPPING,
|
||||||
|
AutoConfig,
|
||||||
|
AutoFeatureExtractor,
|
||||||
|
AutoProcessor,
|
||||||
|
AutoTokenizer,
|
||||||
|
Wav2Vec2Config,
|
||||||
|
Wav2Vec2FeatureExtractor,
|
||||||
|
Wav2Vec2Processor,
|
||||||
|
)
|
||||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_tokenizers_available
|
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_tokenizers_available
|
||||||
from transformers.testing_utils import PASS, USER, is_staging_test
|
from transformers.testing_utils import PASS, USER, is_staging_test
|
||||||
from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
|
from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
|
||||||
@@ -31,6 +43,7 @@ from transformers.tokenization_utils import TOKENIZER_CONFIG_FILE
|
|||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||||
|
|
||||||
|
from test_module.custom_configuration import CustomConfig # noqa E402
|
||||||
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
|
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
|
||||||
from test_module.custom_processing import CustomProcessor # noqa E402
|
from test_module.custom_processing import CustomProcessor # noqa E402
|
||||||
from test_module.custom_tokenization import CustomTokenizer # noqa E402
|
from test_module.custom_tokenization import CustomTokenizer # noqa E402
|
||||||
@@ -45,6 +58,8 @@ SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__fil
|
|||||||
|
|
||||||
|
|
||||||
class AutoFeatureExtractorTest(unittest.TestCase):
|
class AutoFeatureExtractorTest(unittest.TestCase):
|
||||||
|
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
|
||||||
|
|
||||||
def test_processor_from_model_shortcut(self):
|
def test_processor_from_model_shortcut(self):
|
||||||
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
self.assertIsInstance(processor, Wav2Vec2Processor)
|
self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||||
@@ -154,6 +169,42 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
|
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
|
||||||
|
|
||||||
|
def test_new_processor_registration(self):
|
||||||
|
try:
|
||||||
|
AutoConfig.register("custom", CustomConfig)
|
||||||
|
AutoFeatureExtractor.register(CustomConfig, CustomFeatureExtractor)
|
||||||
|
AutoTokenizer.register(CustomConfig, slow_tokenizer_class=CustomTokenizer)
|
||||||
|
AutoProcessor.register(CustomConfig, CustomProcessor)
|
||||||
|
# Trying to register something existing in the Transformers library will raise an error
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
AutoProcessor.register(Wav2Vec2Config, Wav2Vec2Processor)
|
||||||
|
|
||||||
|
# Now that the config is registered, it can be used as any other config with the auto-API
|
||||||
|
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||||
|
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||||
|
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||||
|
tokenizer = CustomTokenizer(vocab_file)
|
||||||
|
|
||||||
|
processor = CustomProcessor(feature_extractor, tokenizer)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
processor.save_pretrained(tmp_dir)
|
||||||
|
new_processor = AutoProcessor.from_pretrained(tmp_dir)
|
||||||
|
self.assertIsInstance(new_processor, CustomProcessor)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if "custom" in CONFIG_MAPPING._extra_content:
|
||||||
|
del CONFIG_MAPPING._extra_content["custom"]
|
||||||
|
if CustomConfig in FEATURE_EXTRACTOR_MAPPING._extra_content:
|
||||||
|
del FEATURE_EXTRACTOR_MAPPING._extra_content[CustomConfig]
|
||||||
|
if CustomConfig in TOKENIZER_MAPPING._extra_content:
|
||||||
|
del TOKENIZER_MAPPING._extra_content[CustomConfig]
|
||||||
|
if CustomConfig in PROCESSOR_MAPPING._extra_content:
|
||||||
|
del PROCESSOR_MAPPING._extra_content[CustomConfig]
|
||||||
|
|
||||||
|
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
class ProcessorPushToHubTester(unittest.TestCase):
|
class ProcessorPushToHubTester(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user