Make sure dynamic objects can be saved and reloaded (#21008)

* Make sure dynamic objects can be saved and reloaded

* Remove processor test
This commit is contained in:
Sylvain Gugger
2023-01-05 07:30:25 -05:00
committed by GitHub
parent bf82c9b74f
commit 12313838d3
12 changed files with 64 additions and 8 deletions

View File

@@ -110,3 +110,9 @@ class AutoConfigTest(unittest.TestCase):
def test_from_pretrained_dynamic_config(self):
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(config.__class__.__name__, "NewModelConfig")
# Test config can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir)
reloaded_config = AutoConfig.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_config.__class__.__name__, "NewModelConfig")

View File

@@ -96,10 +96,16 @@ class AutoFeatureExtractorTest(unittest.TestCase):
_ = AutoFeatureExtractor.from_pretrained("hf-internal-testing/config-no-model")
def test_from_pretrained_dynamic_feature_extractor(self):
model = AutoFeatureExtractor.from_pretrained(
feature_extractor = AutoFeatureExtractor.from_pretrained(
"hf-internal-testing/test_dynamic_feature_extractor", trust_remote_code=True
)
self.assertEqual(model.__class__.__name__, "NewFeatureExtractor")
self.assertEqual(feature_extractor.__class__.__name__, "NewFeatureExtractor")
# Test feature extractor can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained(tmp_dir)
reloaded_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_feature_extractor.__class__.__name__, "NewFeatureExtractor")
def test_new_feature_extractor_registration(self):
try:

View File

@@ -130,10 +130,16 @@ class AutoImageProcessorTest(unittest.TestCase):
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/config-no-model")
def test_from_pretrained_dynamic_image_processor(self):
model = AutoImageProcessor.from_pretrained(
image_processor = AutoImageProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_image_processor", trust_remote_code=True
)
self.assertEqual(model.__class__.__name__, "NewImageProcessor")
self.assertEqual(image_processor.__class__.__name__, "NewImageProcessor")
# Test image processor can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained(tmp_dir)
reloaded_image_processor = AutoImageProcessor.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_image_processor.__class__.__name__, "NewImageProcessor")
def test_new_image_processor_registration(self):
try:

View File

@@ -276,10 +276,28 @@ class AutoModelTest(unittest.TestCase):
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel")
# Test model can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_model.__class__.__name__, "NewModel")
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# This one uses a relative import to a util file, this checks it is downloaded and used properly.
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel")
# Test model can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_model.__class__.__name__, "NewModel")
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_new_model_registration(self):
AutoConfig.register("custom", CustomConfig)

View File

@@ -157,12 +157,12 @@ class AutoFeatureExtractorTest(unittest.TestCase):
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
# Test we can also load the slow version
processor = AutoProcessor.from_pretrained(
new_processor = AutoProcessor.from_pretrained(
"hf-internal-testing/test_dynamic_processor", trust_remote_code=True, use_fast=False
)
tokenizer = processor.tokenizer
self.assertTrue(tokenizer.special_attribute_present)
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
new_tokenizer = new_processor.tokenizer
self.assertTrue(new_tokenizer.special_attribute_present)
self.assertEqual(new_tokenizer.__class__.__name__, "NewTokenizer")
else:
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")

View File

@@ -302,8 +302,15 @@ class AutoTokenizerTest(unittest.TestCase):
def test_from_pretrained_dynamic_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/test_dynamic_tokenizer", trust_remote_code=True)
self.assertTrue(tokenizer.special_attribute_present)
# Test tokenizer can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
reloaded_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertTrue(reloaded_tokenizer.special_attribute_present)
if is_tokenizers_available():
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizerFast")
self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizerFast")
# Test we can also load the slow version
tokenizer = AutoTokenizer.from_pretrained(
@@ -311,8 +318,15 @@ class AutoTokenizerTest(unittest.TestCase):
)
self.assertTrue(tokenizer.special_attribute_present)
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
# Test tokenizer can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir)
reloaded_tokenizer = AutoTokenizer.from_pretrained(tmp_dir, trust_remote_code=True, use_fast=False)
self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer")
self.assertTrue(reloaded_tokenizer.special_attribute_present)
else:
self.assertEqual(tokenizer.__class__.__name__, "NewTokenizer")
self.assertEqual(reloaded_tokenizer.__class__.__name__, "NewTokenizer")
def test_from_pretrained_dynamic_tokenizer_legacy_format(self):
tokenizer = AutoTokenizer.from_pretrained(