From 12313838d33373d06d35b48c3c501fa832f16443 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 5 Jan 2023 07:30:25 -0500 Subject: [PATCH] Make sure dynamic objects can be saved and reloaded (#21008) * Make sure dynamic objects can be saved and reloaded * Remove processor test --- src/transformers/models/auto/auto_factory.py | 1 + .../models/auto/configuration_auto.py | 1 + .../models/auto/feature_extraction_auto.py | 1 + .../models/auto/image_processing_auto.py | 1 + .../models/auto/processing_auto.py | 1 + .../models/auto/tokenization_auto.py | 1 + tests/models/auto/test_configuration_auto.py | 6 ++++++ .../auto/test_feature_extraction_auto.py | 10 ++++++++-- .../models/auto/test_image_processing_auto.py | 10 ++++++++-- tests/models/auto/test_modeling_auto.py | 18 ++++++++++++++++++ tests/models/auto/test_processor_auto.py | 8 ++++---- tests/models/auto/test_tokenization_auto.py | 14 ++++++++++++++ 12 files changed, 64 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 04eb3feaac..d906505987 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -455,6 +455,7 @@ class _BaseAutoModelClass: model_class = get_class_from_dynamic_module( pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs ) + model_class.register_for_auto_class(cls.__name__) return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs ) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 26c00ac3cd..6a49d2f4e2 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -853,6 +853,7 @@ class AutoConfig: config_class = get_class_from_dynamic_module( pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs ) + config_class.register_for_auto_class() return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) elif "model_type" in config_dict: config_class = CONFIG_MAPPING[config_dict["model_type"]] diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index a33affe3ec..3726f9f238 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -340,6 +340,7 @@ class AutoFeatureExtractor: feature_extractor_class = get_class_from_dynamic_module( pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs ) + feature_extractor_class.register_for_auto_class() else: feature_extractor_class = feature_extractor_class_from_name(feature_extractor_class) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 0fee88153a..e23458955c 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -352,6 +352,7 @@ class AutoImageProcessor: image_processor_class = get_class_from_dynamic_module( pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs ) + image_processor_class.register_for_auto_class() else: image_processor_class = image_processor_class_from_name(image_processor_class) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index ee662ae57d..f1ad8f221a 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -256,6 +256,7 @@ class AutoProcessor: processor_class = get_class_from_dynamic_module( pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs ) + processor_class.register_for_auto_class() else: processor_class = processor_class_from_name(processor_class) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c0c700cdf1..0b21273ca9 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -641,6 +641,7 @@ class AutoTokenizer: tokenizer_class = get_class_from_dynamic_module( pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs ) + tokenizer_class.register_for_auto_class() elif use_fast and not config_tokenizer_class.endswith("Fast"): tokenizer_class_candidate = f"{config_tokenizer_class}Fast" diff --git a/tests/models/auto/test_configuration_auto.py b/tests/models/auto/test_configuration_auto.py index 2695082c41..030a03aa6d 100644 --- a/tests/models/auto/test_configuration_auto.py +++ b/tests/models/auto/test_configuration_auto.py @@ -110,3 +110,9 @@ class AutoConfigTest(unittest.TestCase): def test_from_pretrained_dynamic_config(self): config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) self.assertEqual(config.__class__.__name__, "NewModelConfig") + + # Test config can be reloaded. + with tempfile.TemporaryDirectory() as tmp_dir: + config.save_pretrained(tmp_dir) + reloaded_config = AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True) + self.assertEqual(reloaded_config.__class__.__name__, "NewModelConfig") diff --git a/tests/models/auto/test_feature_extraction_auto.py b/tests/models/auto/test_feature_extraction_auto.py index e9d044e8da..35d3ac0fa4 100644 --- a/tests/models/auto/test_feature_extraction_auto.py +++ b/tests/models/auto/test_feature_extraction_auto.py @@ -96,10 +96,16 @@ class AutoFeatureExtractorTest(unittest.TestCase): _ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model") def test_from_pretrained_dynamic_feature_extractor(self): - model = AutoFeatureExtractor.from_pretrained( + feature_extractor = AutoFeatureExtractor.from_pretrained( "hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True ) - self.assertEqual(model.__class__.__name__, "NewFeatureExtractor") + self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor") + + # Test feature extractor can be reloaded. + with tempfile.TemporaryDirectory() as tmp_dir: + feature_extractor.save_pretrained(tmp_dir) + reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_dir, trust_remote_code=True) + self.assertEqual(reloaded_feature_extractor.__class__.__name__, "NewFeatureExtractor") def test_new_feature_extractor_registration(self): try: diff --git a/tests/models/auto/test_image_processing_auto.py b/tests/models/auto/test_image_processing_auto.py index 3d2009d5c8..7b2296e71d 100644 --- a/tests/models/auto/test_image_processing_auto.py +++ b/tests/models/auto/test_image_processing_auto.py @@ -130,10 +130,16 @@ class AutoImageProcessorTest(unittest.TestCase): _ = AutoImageProcessor.from_pretrained("hf-internal-testing/config-no-model") def test_from_pretrained_dynamic_image_processor(self): - model = AutoImageProcessor.from_pretrained( + image_processor = AutoImageProcessor.from_pretrained( "hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True ) - self.assertEqual(model.__class__.__name__, "NewImageProcessor") + self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor") + + # Test image processor can be reloaded. + with tempfile.TemporaryDirectory() as tmp_dir: + image_processor.save_pretrained(tmp_dir) + reloaded_image_processor = AutoImageProcessor.from_pretrained(tmp_dir, trust_remote_code=True) + self.assertEqual(reloaded_image_processor.__class__.__name__, "NewImageProcessor") def test_new_image_processor_registration(self): try: diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index c59abe4cd4..0008aa101b 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -276,10 +276,28 @@ class AutoModelTest(unittest.TestCase): model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True) self.assertEqual(model.__class__.__name__, "NewModel") + # Test model can be reloaded. + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True) + + self.assertEqual(reloaded_model.__class__.__name__, "NewModel") + for p1, p2 in zip(model.parameters(), reloaded_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + # This one uses a relative import to a util file, this checks it is downloaded and used properly. model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True) self.assertEqual(model.__class__.__name__, "NewModel") + # Test model can be reloaded. + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True) + + self.assertEqual(reloaded_model.__class__.__name__, "NewModel") + for p1, p2 in zip(model.parameters(), reloaded_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + def test_new_model_registration(self): AutoConfig.register("custom", CustomConfig) diff --git a/tests/models/auto/test_processor_auto.py b/tests/models/auto/test_processor_auto.py index 6cddfc1376..91cd85a893 100644 --- a/tests/models/auto/test_processor_auto.py +++ b/tests/models/auto/test_processor_auto.py @@ -157,12 +157,12 @@ class AutoFeatureExtractorTest(unittest.TestCase): self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast") # Test we can also load the slow version - processor = AutoProcessor.from_pretrained( + new_processor = AutoProcessor.from_pretrained( "hf-internal-testing/test_dynamic_processor", trust_remote_code=True, use_fast=False ) - tokenizer = processor.tokenizer - self.assertTrue(tokenizer.special_attribute_present) - self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") + new_tokenizer = new_processor.tokenizer + self.assertTrue(new_tokenizer.special_attribute_present) + self.assertEqual(new_tokenizer.__class__.__name__, "NewTokenizer") else: self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") diff --git a/tests/models/auto/test_tokenization_auto.py b/tests/models/auto/test_tokenization_auto.py index 020eea72cd..5814a76c37 100644 --- a/tests/models/auto/test_tokenization_auto.py +++ b/tests/models/auto/test_tokenization_auto.py @@ -302,8 +302,15 @@ class AutoTokenizerTest(unittest.TestCase): def test_from_pretrained_dynamic_tokenizer(self): tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True) self.assertTrue(tokenizer.special_attribute_present) + # Test tokenizer can be reloaded. + with tempfile.TemporaryDirectory() as tmp_dir: + tokenizer.save_pretrained(tmp_dir) + reloaded_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, trust_remote_code=True) + self.assertTrue(reloaded_tokenizer.special_attribute_present) + if is_tokenizers_available(): self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast") + self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizerFast") # Test we can also load the slow version tokenizer = AutoTokenizer.from_pretrained( @@ -311,8 +318,15 @@ class AutoTokenizerTest(unittest.TestCase): ) self.assertTrue(tokenizer.special_attribute_present) self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") + # Test tokenizer can be reloaded. + with tempfile.TemporaryDirectory() as tmp_dir: + tokenizer.save_pretrained(tmp_dir) + reloaded_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, trust_remote_code=True, use_fast=False) + self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer") + self.assertTrue(reloaded_tokenizer.special_attribute_present) else: self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer") + self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer") def test_from_pretrained_dynamic_tokenizer_legacy_format(self): tokenizer = AutoTokenizer.from_pretrained(