Follow up for #31973 (#32025)

* fix

* [test_all] trigger full CI

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-07-25 16:12:23 +02:00
committed by GitHub
parent de2318894e
commit df6eee9201
9 changed files with 808 additions and 673 deletions

View File

@@ -18,10 +18,10 @@ import os
import tempfile import tempfile
import unittest import unittest
import warnings import warnings
from pathlib import Path
from huggingface_hub import HfFolder, delete_repo from huggingface_hub import HfFolder, delete_repo
from parameterized import parameterized from parameterized import parameterized
from requests.exceptions import HTTPError
from transformers import AutoConfig, GenerationConfig from transformers import AutoConfig, GenerationConfig
from transformers.generation import GenerationMode from transformers.generation import GenerationMode
@@ -228,72 +228,88 @@ class ConfigPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try: try:
delete_repo(token=cls._token, repo_id="test-generation-config") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="valid_org/test-generation-config-org")
except HTTPError:
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
config.push_to_hub("test-generation-config", token=self._token)
new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="test-generation-config")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, repo_id="test-generation-config", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
config.push_to_hub(tmp_repo, token=self._token)
new_config = GenerationConfig.from_pretrained(f"{USER}/test-generation-config") new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items(): for k, v in config.to_dict().items():
if k != "transformers_version": if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k)) self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-generation-config-{Path(tmp_dir).name}"
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
# Push to hub via save_pretrained
config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
config.push_to_hub("valid_org/test-generation-config-org", token=self._token)
new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-generation-config-org")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-generation-config-org", push_to_hub=True, token=self._token tmp_repo = f"valid_org/test-generation-config-org-{Path(tmp_dir).name}"
) config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
config.push_to_hub(tmp_repo, token=self._token)
new_config = GenerationConfig.from_pretrained("valid_org/test-generation-config-org") new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items(): for k, v in config.to_dict().items():
if k != "transformers_version": if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k)) self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-generation-config-org-{Path(tmp_dir).name}"
config = GenerationConfig(
do_sample=True,
temperature=0.7,
length_penalty=1.0,
)
# Push to hub via save_pretrained
config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)

View File

@@ -20,10 +20,8 @@ import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from uuid import uuid4
from huggingface_hub import HfFolder, Repository, create_repo, delete_repo from huggingface_hub import HfFolder, Repository, create_repo, delete_repo
from requests.exceptions import HTTPError
import transformers import transformers
from transformers import ( from transformers import (
@@ -374,69 +372,73 @@ class ProcessorPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try: try:
delete_repo(token=cls._token, repo_id="test-processor") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
except: # noqa E722
pass pass
try: def test_push_to_hub_via_save_pretrained(self):
delete_repo(token=cls._token, repo_id="valid_org/test-processor-org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-processor")
except HTTPError:
pass
def test_push_to_hub(self):
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
processor.save_pretrained(os.path.join(tmp_dir, "test-processor"), push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-processor-{Path(tmp_dir).name}"
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
# Push to hub via save_pretrained
processor.save_pretrained(tmp_repo, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_processor = Wav2Vec2Processor.from_pretrained(f"{USER}/test-processor") new_processor = Wav2Vec2Processor.from_pretrained(tmp_repo)
for k, v in processor.feature_extractor.__dict__.items(): for k, v in processor.feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_processor.feature_extractor, k)) self.assertEqual(v, getattr(new_processor.feature_extractor, k))
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab()) self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
finally:
def test_push_to_hub_in_organization(self): # Always (try to) delete the repo.
processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR) self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
processor.save_pretrained( try:
os.path.join(tmp_dir, "test-processor-org"), tmp_repo = f"valid_org/test-processor-org-{Path(tmp_dir).name}"
push_to_hub=True, processor = Wav2Vec2Processor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
token=self._token,
organization="valid_org",
)
new_processor = Wav2Vec2Processor.from_pretrained("valid_org/test-processor-org") # Push to hub via save_pretrained
for k, v in processor.feature_extractor.__dict__.items(): processor.save_pretrained(
self.assertEqual(v, getattr(new_processor.feature_extractor, k)) tmp_dir,
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab()) repo_id=tmp_repo,
push_to_hub=True,
token=self._token,
)
new_processor = Wav2Vec2Processor.from_pretrained(tmp_repo)
for k, v in processor.feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_processor.feature_extractor, k))
self.assertDictEqual(new_processor.tokenizer.get_vocab(), processor.tokenizer.get_vocab())
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_processor(self): def test_push_to_hub_dynamic_processor(self):
CustomFeatureExtractor.register_for_auto_class()
CustomTokenizer.register_for_auto_class()
CustomProcessor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt") try:
with open(vocab_file, "w", encoding="utf-8") as vocab_writer: tmp_repo = f"{USER}/test-dynamic-processor-{Path(tmp_dir).name}"
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = CustomTokenizer(vocab_file)
processor = CustomProcessor(feature_extractor, tokenizer) CustomFeatureExtractor.register_for_auto_class()
CustomTokenizer.register_for_auto_class()
CustomProcessor.register_for_auto_class()
random_repo_id = f"{USER}/test-dynamic-processor-{uuid4()}" feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_PROCESSOR_CONFIG_DIR)
try:
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
create_repo(random_repo_id, token=self._token) vocab_file = os.path.join(tmp_dir, "vocab.txt")
repo = Repository(tmp_dir, clone_from=random_repo_id, token=self._token) with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = CustomTokenizer(vocab_file)
processor = CustomProcessor(feature_extractor, tokenizer)
create_repo(tmp_repo, token=self._token)
repo = Repository(tmp_dir, clone_from=tmp_repo, token=self._token)
processor.save_pretrained(tmp_dir) processor.save_pretrained(tmp_dir)
# This has added the proper auto_map field to the feature extractor config # This has added the proper auto_map field to the feature extractor config
@@ -466,8 +468,10 @@ class ProcessorPushToHubTester(unittest.TestCase):
repo.push_to_hub() repo.push_to_hub()
new_processor = AutoProcessor.from_pretrained(random_repo_id, trust_remote_code=True) new_processor = AutoProcessor.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module # Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module
self.assertEqual(new_processor.__class__.__name__, "CustomProcessor") self.assertEqual(new_processor.__class__.__name__, "CustomProcessor")
finally:
delete_repo(repo_id=random_repo_id) finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)

View File

@@ -98,88 +98,106 @@ class ConfigPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try: try:
delete_repo(token=cls._token, repo_id="test-config") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="valid_org/test-config-org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-config")
except HTTPError:
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.push_to_hub("test-config", token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="test-config")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-config-{Path(tmp_dir).name}"
new_config = BertConfig.from_pretrained(f"{USER}/test-config") config = BertConfig(
for k, v in config.to_dict().items(): vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
if k != "transformers_version": )
self.assertEqual(v, getattr(new_config, k)) config.push_to_hub(tmp_repo, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-config-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
# Push to hub via save_pretrained
config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.push_to_hub("valid_org/test-config-org", token=self._token)
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-config-org")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
config.save_pretrained(tmp_dir, repo_id="valid_org/test-config-org", push_to_hub=True, token=self._token) try:
tmp_repo = f"valid_org/test-config-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
config.push_to_hub(tmp_repo, token=self._token)
new_config = BertConfig.from_pretrained("valid_org/test-config-org") new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items(): for k, v in config.to_dict().items():
if k != "transformers_version": if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k)) self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-config-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
# Push to hub via save_pretrained
config.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_config(self): def test_push_to_hub_dynamic_config(self):
CustomConfig.register_for_auto_class() with tempfile.TemporaryDirectory() as tmp_dir:
config = CustomConfig(attribute=42) try:
tmp_repo = f"{USER}/test-dynamic-config-{Path(tmp_dir).name}"
config.push_to_hub("test-dynamic-config", token=self._token) CustomConfig.register_for_auto_class()
config = CustomConfig(attribute=42)
# This has added the proper auto_map field to the config config.push_to_hub(tmp_repo, token=self._token)
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True) # This has added the proper auto_map field to the config
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
self.assertEqual(new_config.__class__.__name__, "CustomConfig")
self.assertEqual(new_config.attribute, 42) new_config = AutoConfig.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
self.assertEqual(new_config.__class__.__name__, "CustomConfig")
self.assertEqual(new_config.attribute, 42)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
class ConfigTestUtils(unittest.TestCase): class ConfigTestUtils(unittest.TestCase):

View File

@@ -60,85 +60,91 @@ class FeatureExtractorPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try: try:
delete_repo(token=cls._token, repo_id="test-feature-extractor") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="valid_org/test-feature-extractor-org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-feature-extractor")
except HTTPError:
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub("test-feature-extractor", token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="test-feature-extractor")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained( try:
tmp_dir, repo_id="test-feature-extractor", push_to_hub=True, token=self._token tmp_repo = f"{USER}/test-feature-extractor-{Path(tmp_dir).name}"
)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor") feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
for k, v in feature_extractor.__dict__.items(): feature_extractor.push_to_hub(tmp_repo, token=self._token)
self.assertEqual(v, getattr(new_feature_extractor, k))
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-feature-extractor-{Path(tmp_dir).name}"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
# Push to hub via save_pretrained
feature_extractor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub("valid_org/test-feature-extractor", token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor")
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-feature-extractor")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-feature-extractor-org", push_to_hub=True, token=self._token tmp_repo = f"valid_org/test-feature-extractor-{Path(tmp_dir).name}"
) feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub(tmp_repo, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org") new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items(): for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k)) self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-feature-extractor-{Path(tmp_dir).name}"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
# Push to hub via save_pretrained
feature_extractor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo)
for k, v in feature_extractor.__dict__.items():
self.assertEqual(v, getattr(new_feature_extractor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_feature_extractor(self): def test_push_to_hub_dynamic_feature_extractor(self):
CustomFeatureExtractor.register_for_auto_class() with tempfile.TemporaryDirectory() as tmp_dir:
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR) try:
tmp_repo = f"{USER}/test-dynamic-feature-extractor-{Path(tmp_dir).name}"
CustomFeatureExtractor.register_for_auto_class()
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
feature_extractor.push_to_hub("test-dynamic-feature-extractor", token=self._token) feature_extractor.push_to_hub(tmp_repo, token=self._token)
# This has added the proper auto_map field to the config # This has added the proper auto_map field to the config
self.assertDictEqual( self.assertDictEqual(
feature_extractor.auto_map, feature_extractor.auto_map,
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"}, {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
) )
new_feature_extractor = AutoFeatureExtractor.from_pretrained( new_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_repo, trust_remote_code=True)
f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True # Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
) self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
# Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module finally:
self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor") # Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)

View File

@@ -71,88 +71,93 @@ class ImageProcessorPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try: try:
delete_repo(token=cls._token, repo_id="test-image-processor") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="valid_org/test-image-processor-org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-image-processor")
except HTTPError:
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub("test-image-processor", token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor")
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="test-image-processor")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained( try:
tmp_dir, repo_id="test-image-processor", push_to_hub=True, token=self._token tmp_repo = f"{USER}/test-image-processor-{Path(tmp_dir).name}"
) image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub(tmp_repo, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor") new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
for k, v in image_processor.__dict__.items(): for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k)) self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-image-processor-{Path(tmp_dir).name}"
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
# Push to hub via save_pretrained
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub("valid_org/test-image-processor", token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor")
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-image-processor")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
image_processor.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-image-processor-org", push_to_hub=True, token=self._token tmp_repo = f"valid_org/test-image-processor-{Path(tmp_dir).name}"
) image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub(tmp_repo, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor-org") new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
for k, v in image_processor.__dict__.items(): for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k)) self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-image-processor-{Path(tmp_dir).name}"
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
# Push to hub via save_pretrained
image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo)
for k, v in image_processor.__dict__.items():
self.assertEqual(v, getattr(new_image_processor, k))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_image_processor(self): def test_push_to_hub_dynamic_image_processor(self):
CustomImageProcessor.register_for_auto_class() with tempfile.TemporaryDirectory() as tmp_dir:
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR) try:
tmp_repo = f"{USER}/test-dynamic-image-processor-{Path(tmp_dir).name}"
CustomImageProcessor.register_for_auto_class()
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
image_processor.push_to_hub("test-dynamic-image-processor", token=self._token) image_processor.push_to_hub(tmp_repo, token=self._token)
# This has added the proper auto_map field to the config # This has added the proper auto_map field to the config
self.assertDictEqual( self.assertDictEqual(
image_processor.auto_map, image_processor.auto_map,
{"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"}, {"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"},
) )
new_image_processor = AutoImageProcessor.from_pretrained( new_image_processor = AutoImageProcessor.from_pretrained(tmp_repo, trust_remote_code=True)
f"{USER}/test-dynamic-image-processor", trust_remote_code=True # Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
) self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module finally:
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor") # Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
class ImageProcessingUtilsTester(unittest.TestCase): class ImageProcessingUtilsTester(unittest.TestCase):

View File

@@ -14,10 +14,10 @@
import tempfile import tempfile
import unittest import unittest
from pathlib import Path
import numpy as np import numpy as np
from huggingface_hub import HfFolder, delete_repo, snapshot_download from huggingface_hub import HfFolder, delete_repo, snapshot_download
from requests.exceptions import HTTPError
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
from transformers.testing_utils import ( from transformers.testing_utils import (
@@ -55,89 +55,103 @@ class FlaxModelPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try: try:
delete_repo(token=cls._token, repo_id="test-model-flax") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-flax-org")
except HTTPError:
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
model.push_to_hub("test-model-flax", token=self._token)
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
try:
# Reset repo
delete_repo(token=self._token, repo_id="test-model-flax")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id="test-model-flax", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-model-flax-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
model.push_to_hub(tmp_repo, token=self._token)
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax") new_model = FlaxBertModel.from_pretrained(tmp_repo)
base_params = flatten_dict(unfreeze(model.params)) base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params)) new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys(): for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item() max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-model-flax-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_model = FlaxBertModel.from_pretrained(tmp_repo)
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
model.push_to_hub("valid_org/test-model-flax-org", token=self._token)
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-model-flax-org")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-model-flax-org", push_to_hub=True, token=self._token tmp_repo = f"valid_org/test-model-flax-org-{Path(tmp_dir).name}"
) config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
model.push_to_hub(tmp_repo, token=self._token)
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org") new_model = FlaxBertModel.from_pretrained(tmp_repo)
base_params = flatten_dict(unfreeze(model.params)) base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params)) new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys(): for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item() max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-model-flax-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = FlaxBertModel(config)
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_model = FlaxBertModel.from_pretrained(tmp_repo)
base_params = flatten_dict(unfreeze(model.params))
new_params = flatten_dict(unfreeze(new_model.params))
for key in base_params.keys():
max_diff = (base_params[key] - new_params[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def check_models_equal(model1, model2): def check_models_equal(model1, model2):

View File

@@ -23,6 +23,7 @@ import random
import tempfile import tempfile
import unittest import unittest
import unittest.mock as mock import unittest.mock as mock
from pathlib import Path
from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
@@ -682,127 +683,149 @@ class TFModelPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try: try:
delete_repo(token=cls._token, repo_id="test-model-tf") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="test-model-tf-callback")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org")
except HTTPError:
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build_in_name_scope()
logging.set_verbosity_info()
logger = logging.get_logger("transformers.utils.hub")
with CaptureLogger(logger) as cl:
model.push_to_hub("test-model-tf", token=self._token)
logging.set_verbosity_warning()
# Check the model card was created and uploaded.
self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
try:
# Reset repo
delete_repo(token=self._token, repo_id="test-model-tf")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id="test-model-tf", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-model-tf-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build_in_name_scope()
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf") logging.set_verbosity_info()
models_equal = True logger = logging.get_logger("transformers.utils.hub")
for p1, p2 in zip(model.weights, new_model.weights): with CaptureLogger(logger) as cl:
if not tf.math.reduce_all(p1 == p2): model.push_to_hub(tmp_repo, token=self._token)
models_equal = False logging.set_verbosity_warning()
break # Check the model card was created and uploaded.
self.assertTrue(models_equal) self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-model-tf-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build_in_name_scope()
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@is_pt_tf_cross_test @is_pt_tf_cross_test
def test_push_to_hub_callback(self): def test_push_to_hub_callback(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertForMaskedLM(config)
model.compile()
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
push_to_hub_callback = PushToHubCallback( try:
output_dir=tmp_dir, tmp_repo = f"{USER}/test-model-tf-callback-{Path(tmp_dir).name}"
hub_model_id="test-model-tf-callback", config = BertConfig(
hub_token=self._token, vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
) )
model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback]) model = TFBertForMaskedLM(config)
model.compile()
new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback") push_to_hub_callback = PushToHubCallback(
models_equal = True output_dir=tmp_dir,
for p1, p2 in zip(model.weights, new_model.weights): hub_model_id=tmp_repo,
if not tf.math.reduce_all(p1 == p2): hub_token=self._token,
models_equal = False )
break model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])
self.assertTrue(models_equal)
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters) new_model = TFBertForMaskedLM.from_pretrained(tmp_repo)
tf_push_to_hub_params.pop("base_model_card_args") models_equal = True
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters) for p1, p2 in zip(model.weights, new_model.weights):
pt_push_to_hub_params.pop("deprecated_kwargs") if not tf.math.reduce_all(p1 == p2):
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params) models_equal = False
break
self.assertTrue(models_equal)
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
tf_push_to_hub_params.pop("base_model_card_args")
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
pt_push_to_hub_params.pop("deprecated_kwargs")
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build_in_name_scope()
model.push_to_hub("valid_org/test-model-tf-org", token=self._token)
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-model-tf-org")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id="valid_org/test-model-tf-org") try:
tmp_repo = f"valid_org/test-model-tf-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build_in_name_scope()
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org") model.push_to_hub(tmp_repo, token=self._token)
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights): new_model = TFBertModel.from_pretrained(tmp_repo)
if not tf.math.reduce_all(p1 == p2): models_equal = True
models_equal = False for p1, p2 in zip(model.weights, new_model.weights):
break if not tf.math.reduce_all(p1 == p2):
self.assertTrue(models_equal) models_equal = False
break
self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-model-tf-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
# Make sure model is properly initialized
model.build_in_name_scope()
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id=tmp_repo)
new_model = TFBertModel.from_pretrained(tmp_repo)
models_equal = True
for p1, p2 in zip(model.weights, new_model.weights):
if not tf.math.reduce_all(p1 == p2):
models_equal = False
break
self.assertTrue(models_equal)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)

View File

@@ -1876,142 +1876,168 @@ class ModelPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try: try:
delete_repo(token=cls._token, repo_id="test-model") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="valid_org/test-model-org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-model")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags")
except HTTPError:
pass pass
@unittest.skip(reason="This test is flaky") @unittest.skip(reason="This test is flaky")
def test_push_to_hub(self): def test_push_to_hub(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
model.push_to_hub("test-model", token=self._token)
new_model = BertModel.from_pretrained(f"{USER}/test-model")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
try:
# Reset repo
delete_repo(token=self._token, repo_id="test-model")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, repo_id="test-model", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-model-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
model.push_to_hub(tmp_repo, token=self._token)
new_model = BertModel.from_pretrained(f"{USER}/test-model") new_model = BertModel.from_pretrained(tmp_repo)
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@unittest.skip(reason="This test is flaky")
def test_push_to_hub_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-model-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_model = BertModel.from_pretrained(tmp_repo)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_with_description(self): def test_push_to_hub_with_description(self):
config = BertConfig( with tempfile.TemporaryDirectory() as tmp_dir:
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 try:
) tmp_repo = f"{USER}/test-model-{Path(tmp_dir).name}"
model = BertModel(config) config = BertConfig(
COMMIT_DESCRIPTION = """ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
COMMIT_DESCRIPTION = """
The commit description supports markdown synthax see: The commit description supports markdown synthax see:
```python ```python
>>> form transformers import AutoConfig >>> form transformers import AutoConfig
>>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased") >>> config = AutoConfig.from_pretrained("google-bert/bert-base-uncased")
``` ```
""" """
commit_details = model.push_to_hub( commit_details = model.push_to_hub(
"test-model", use_auth_token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION tmp_repo, use_auth_token=self._token, create_pr=True, commit_description=COMMIT_DESCRIPTION
) )
self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION) self.assertEqual(commit_details.commit_description, COMMIT_DESCRIPTION)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@unittest.skip(reason="This test is flaky") @unittest.skip(reason="This test is flaky")
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
model.push_to_hub("valid_org/test-model-org", token=self._token)
new_model = BertModel.from_pretrained("valid_org/test-model-org")
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
try:
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-model-org")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id="valid_org/test-model-org") try:
tmp_repo = f"valid_org/test-model-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
model.push_to_hub(tmp_repo, token=self._token)
new_model = BertModel.from_pretrained("valid_org/test-model-org") new_model = BertModel.from_pretrained(tmp_repo)
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@unittest.skip(reason="This test is flaky")
def test_push_to_hub_in_organization_via_save_pretrained(self):
with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"valid_org/test-model-org-{Path(tmp_dir).name}"
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = BertModel(config)
# Push to hub via save_pretrained
model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id=tmp_repo)
new_model = BertModel.from_pretrained(tmp_repo)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_dynamic_model(self): def test_push_to_hub_dynamic_model(self):
CustomConfig.register_for_auto_class() with tempfile.TemporaryDirectory() as tmp_dir:
CustomModel.register_for_auto_class() try:
tmp_repo = f"{USER}/test-dynamic-model-{Path(tmp_dir).name}"
CustomConfig.register_for_auto_class()
CustomModel.register_for_auto_class()
config = CustomConfig(hidden_size=32) config = CustomConfig(hidden_size=32)
model = CustomModel(config) model = CustomModel(config)
model.push_to_hub("test-dynamic-model", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
# checks # checks
self.assertDictEqual( self.assertDictEqual(
config.auto_map, config.auto_map,
{"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"}, {"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
) )
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True) new_model = AutoModel.from_pretrained(tmp_repo, trust_remote_code=True)
# Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module # Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
self.assertEqual(new_model.__class__.__name__, "CustomModel") self.assertEqual(new_model.__class__.__name__, "CustomModel")
for p1, p2 in zip(model.parameters(), new_model.parameters()): for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2)) self.assertTrue(torch.equal(p1, p2))
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True) config = AutoConfig.from_pretrained(tmp_repo, trust_remote_code=True)
new_model = AutoModel.from_config(config, trust_remote_code=True) new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "CustomModel") self.assertEqual(new_model.__class__.__name__, "CustomModel")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_with_tags(self): def test_push_to_hub_with_tags(self):
from huggingface_hub import ModelCard with tempfile.TemporaryDirectory() as tmp_dir:
try:
tmp_repo = f"{USER}/test-dynamic-model-with-tags-{Path(tmp_dir).name}"
from huggingface_hub import ModelCard
new_tags = ["tag-1", "tag-2"] new_tags = ["tag-1", "tag-2"]
CustomConfig.register_for_auto_class() CustomConfig.register_for_auto_class()
CustomModel.register_for_auto_class() CustomModel.register_for_auto_class()
config = CustomConfig(hidden_size=32) config = CustomConfig(hidden_size=32)
model = CustomModel(config) model = CustomModel(config)
self.assertTrue(model.model_tags is None) self.assertTrue(model.model_tags is None)
model.add_model_tags(new_tags) model.add_model_tags(new_tags)
self.assertTrue(model.model_tags == new_tags) self.assertTrue(model.model_tags == new_tags)
model.push_to_hub("test-dynamic-model-with-tags", token=self._token) model.push_to_hub(tmp_repo, token=self._token)
loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags") loaded_model_card = ModelCard.load(tmp_repo)
self.assertEqual(loaded_model_card.data.tags, new_tags) self.assertEqual(loaded_model_card.data.tags, new_tags)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@require_torch @require_torch

View File

@@ -118,110 +118,133 @@ class TokenizerPushToHubTester(unittest.TestCase):
cls._token = TOKEN cls._token = TOKEN
HfFolder.save_token(TOKEN) HfFolder.save_token(TOKEN)
@classmethod @staticmethod
def tearDownClass(cls): def _try_delete_repo(repo_id, token):
try: try:
delete_repo(token=cls._token, repo_id="test-tokenizer") # Reset repo
except HTTPError: delete_repo(repo_id=repo_id, token=token)
pass except: # noqa E722
try:
delete_repo(token=cls._token, repo_id="valid_org/test-tokenizer-org")
except HTTPError:
pass
try:
delete_repo(token=cls._token, repo_id="test-dynamic-tokenizer")
except HTTPError:
pass pass
def test_push_to_hub(self): def test_push_to_hub(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt") try:
with open(vocab_file, "w", encoding="utf-8") as vocab_writer: tmp_repo = f"{USER}/test-tokenizer-{Path(tmp_dir).name}"
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) vocab_file = os.path.join(tmp_dir, "vocab.txt")
tokenizer = BertTokenizer(vocab_file) with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.push_to_hub("test-tokenizer", token=self._token) tokenizer.push_to_hub(tmp_repo, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer") new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
try: def test_push_to_hub_via_save_pretrained(self):
# Reset repo
delete_repo(token=self._token, repo_id="test-tokenizer")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, token=self._token) try:
tmp_repo = f"{USER}/test-tokenizer-{Path(tmp_dir).name}"
vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer") # Push to hub via save_pretrained
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
def test_push_to_hub_in_organization(self): def test_push_to_hub_in_organization(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt") try:
with open(vocab_file, "w", encoding="utf-8") as vocab_writer: tmp_repo = f"valid_org/test-tokenizer-{Path(tmp_dir).name}"
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) vocab_file = os.path.join(tmp_dir, "vocab.txt")
tokenizer = BertTokenizer(vocab_file) with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
tokenizer.push_to_hub("valid_org/test-tokenizer-org", token=self._token) tokenizer.push_to_hub(tmp_repo, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org") new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
try: def test_push_to_hub_in_organization_via_save_pretrained(self):
# Reset repo
delete_repo(token=self._token, repo_id="valid_org/test-tokenizer-org")
except: # noqa E722
pass
# Push to hub via save_pretrained
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer.save_pretrained( try:
tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, token=self._token tmp_repo = f"valid_org/test-tokenizer-{Path(tmp_dir).name}"
) vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = BertTokenizer(vocab_file)
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org") # Push to hub via save_pretrained
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab) tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo, push_to_hub=True, token=self._token)
new_tokenizer = BertTokenizer.from_pretrained(tmp_repo)
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@require_tokenizers @require_tokenizers
def test_push_to_hub_dynamic_tokenizer(self): def test_push_to_hub_dynamic_tokenizer(self):
CustomTokenizer.register_for_auto_class()
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt") try:
with open(vocab_file, "w", encoding="utf-8") as vocab_writer: tmp_repo = f"{USER}/test-dynamic-tokenizer-{Path(tmp_dir).name}"
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) CustomTokenizer.register_for_auto_class()
tokenizer = CustomTokenizer(vocab_file)
# No fast custom tokenizer vocab_file = os.path.join(tmp_dir, "vocab.txt")
tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token) with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = CustomTokenizer(vocab_file)
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True) # No fast custom tokenizer
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module tokenizer.push_to_hub(tmp_repo, token=self._token)
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
# Fast and slow custom tokenizer tokenizer = AutoTokenizer.from_pretrained(tmp_repo, trust_remote_code=True)
CustomTokenizerFast.register_for_auto_class() # Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
@require_tokenizers
def test_push_to_hub_dynamic_tokenizer_with_both_slow_and_fast_classes(self):
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
vocab_file = os.path.join(tmp_dir, "vocab.txt") try:
with open(vocab_file, "w", encoding="utf-8") as vocab_writer: tmp_repo = f"{USER}/test-dynamic-tokenizer-{Path(tmp_dir).name}"
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens])) CustomTokenizer.register_for_auto_class()
bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir) # Fast and slow custom tokenizer
bert_tokenizer.save_pretrained(tmp_dir) CustomTokenizerFast.register_for_auto_class()
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token) vocab_file = os.path.join(tmp_dir, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True) bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module bert_tokenizer.save_pretrained(tmp_dir)
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast") tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
tokenizer = AutoTokenizer.from_pretrained(
f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True tokenizer.push_to_hub(tmp_repo, token=self._token)
)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module tokenizer = AutoTokenizer.from_pretrained(tmp_repo, trust_remote_code=True)
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer") # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
tokenizer = AutoTokenizer.from_pretrained(tmp_repo, use_fast=False, trust_remote_code=True)
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
finally:
# Always (try to) delete the repo.
self._try_delete_repo(repo_id=tmp_repo, token=self._token)
class TrieTest(unittest.TestCase): class TrieTest(unittest.TestCase):