Make using safetensors files automated. (#27571)
* [WIP] Make using safetensors files automated. If `use_safetensors=True` is used, and it doesn't exist: - Don't crash just yet - Lookup for an open PR containing it. - If yes, use that instead - If not, touch the space to convert, wait for conversion to be finished and the PR to be opened - Use that new PR - Profit. * Remove the token. * [Auto Safetensors] Websocket -> SSE (#27656) * Websocket -> SSE * Support sharded + tests +cleanup a * env var * Apply suggestions from code review * Thanks Simon * Thanks Wauplin Co-authored-by: Wauplin <lucainp@gmail.com> * Cleanup * Update tests * Tests should pass * Apply to other tests * Extend extension * relax requirement on latest hfh * Revert * Correct private handling & debug statements * Skip gated repos as of now * Address review comments Co-authored-by: ArthurZucker <arthur.zucker@gmail.com> --------- Co-authored-by: Lysandre Debut <hi@lysand.re> Co-authored-by: Lysandre <lysandre@huggingface.co> Co-authored-by: Wauplin <lucainp@gmail.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr> Co-authored-by: ArthurZucker <arthur.zucker@gmail.com>
This commit is contained in:
@@ -21,9 +21,11 @@ import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfFolder, delete_repo
|
||||
import requests
|
||||
from huggingface_hub import HfApi, HfFolder, delete_repo
|
||||
from huggingface_hub.file_download import http_get
|
||||
from pytest import mark
|
||||
from requests.exceptions import HTTPError
|
||||
@@ -829,14 +831,8 @@ class ModelUtilsTest(TestCasePlus):
|
||||
|
||||
@require_safetensors
|
||||
def test_use_safetensors(self):
|
||||
# test nice error message if no safetensor files available
|
||||
with self.assertRaises(OSError) as env_error:
|
||||
AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)
|
||||
|
||||
self.assertTrue(
|
||||
"model.safetensors or model.safetensors.index.json and thus cannot be loaded with `safetensors`"
|
||||
in str(env_error.exception)
|
||||
)
|
||||
# Should not raise anymore
|
||||
AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)
|
||||
|
||||
# test that error if only safetensors is available
|
||||
with self.assertRaises(OSError) as env_error:
|
||||
@@ -1171,6 +1167,202 @@ class ModelUtilsTest(TestCasePlus):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModelOnTheFlyConversionTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.user = "huggingface-hub-ci"
|
||||
cls.token = os.getenv("HUGGINGFACE_PRODUCTION_USER_TOKEN", None)
|
||||
|
||||
if cls.token is None:
|
||||
raise ValueError("Cannot run tests as secret isn't setup.")
|
||||
|
||||
cls.api = HfApi(token=cls.token)
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.repo_name = f"{self.user}/test-model-on-the-fly-{uuid.uuid4()}"
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.api.delete_repo(self.repo_name)
|
||||
|
||||
def test_safetensors_on_the_fly_conversion(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
initial_model = BertModel(config)
|
||||
|
||||
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
|
||||
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True)
|
||||
|
||||
with self.subTest("Initial and converted models are equal"):
|
||||
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
with self.subTest("PR was open with the safetensors account"):
|
||||
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||
discussion = next(discussions)
|
||||
self.assertEqual(discussion.author, "SFconvertbot")
|
||||
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||
|
||||
def test_safetensors_on_the_fly_conversion_private(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
initial_model = BertModel(config)
|
||||
|
||||
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, private=True)
|
||||
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||
|
||||
with self.subTest("Initial and converted models are equal"):
|
||||
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
with self.subTest("PR was open with the safetensors account"):
|
||||
discussions = self.api.get_repo_discussions(self.repo_name, token=self.token)
|
||||
discussion = next(discussions)
|
||||
self.assertEqual(discussion.author, self.user)
|
||||
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||
|
||||
def test_safetensors_on_the_fly_conversion_gated(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
initial_model = BertModel(config)
|
||||
|
||||
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
|
||||
headers = {"Authorization": f"Bearer {self.token}"}
|
||||
requests.put(
|
||||
f"https://huggingface.co/api/models/{self.repo_name}/settings", json={"gated": "auto"}, headers=headers
|
||||
)
|
||||
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||
|
||||
with self.subTest("Initial and converted models are equal"):
|
||||
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
with self.subTest("PR was open with the safetensors account"):
|
||||
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||
discussion = next(discussions)
|
||||
self.assertEqual(discussion.author, "SFconvertbot")
|
||||
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||
|
||||
def test_safetensors_on_the_fly_sharded_conversion(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
initial_model = BertModel(config)
|
||||
|
||||
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, max_shard_size="200kb")
|
||||
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True)
|
||||
|
||||
with self.subTest("Initial and converted models are equal"):
|
||||
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
with self.subTest("PR was open with the safetensors account"):
|
||||
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||
discussion = next(discussions)
|
||||
self.assertEqual(discussion.author, "SFconvertbot")
|
||||
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||
|
||||
def test_safetensors_on_the_fly_sharded_conversion_private(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
initial_model = BertModel(config)
|
||||
|
||||
initial_model.push_to_hub(
|
||||
self.repo_name, token=self.token, safe_serialization=False, max_shard_size="200kb", private=True
|
||||
)
|
||||
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||
|
||||
with self.subTest("Initial and converted models are equal"):
|
||||
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
with self.subTest("PR was open with the safetensors account"):
|
||||
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||
discussion = next(discussions)
|
||||
self.assertEqual(discussion.author, self.user)
|
||||
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||
|
||||
def test_safetensors_on_the_fly_sharded_conversion_gated(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
initial_model = BertModel(config)
|
||||
|
||||
initial_model.push_to_hub(self.repo_name, token=self.token, max_shard_size="200kb", safe_serialization=False)
|
||||
headers = {"Authorization": f"Bearer {self.token}"}
|
||||
requests.put(
|
||||
f"https://huggingface.co/api/models/{self.repo_name}/settings", json={"gated": "auto"}, headers=headers
|
||||
)
|
||||
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||
|
||||
with self.subTest("Initial and converted models are equal"):
|
||||
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||
self.assertTrue(torch.equal(p1, p2))
|
||||
|
||||
with self.subTest("PR was open with the safetensors account"):
|
||||
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||
discussion = next(discussions)
|
||||
self.assertEqual(discussion.author, "SFconvertbot")
|
||||
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||
|
||||
@unittest.skip("Edge case, should work once the Space is updated`")
|
||||
def test_safetensors_on_the_fly_wrong_user_opened_pr(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
initial_model = BertModel(config)
|
||||
|
||||
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, private=True)
|
||||
BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||
|
||||
# This should have opened a PR with the user's account
|
||||
with self.subTest("PR was open with the safetensors account"):
|
||||
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||
discussion = next(discussions)
|
||||
self.assertEqual(discussion.author, self.user)
|
||||
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||
|
||||
# We now switch the repo visibility to public
|
||||
self.api.update_repo_visibility(self.repo_name, private=False)
|
||||
|
||||
# We once again call from_pretrained, which should call the bot to open a PR
|
||||
BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||
|
||||
with self.subTest("PR was open with the safetensors account"):
|
||||
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||
|
||||
bot_opened_pr = None
|
||||
bot_opened_pr_title = None
|
||||
|
||||
for discussion in discussions:
|
||||
if discussion.author == "SFconvertBot":
|
||||
bot_opened_pr = True
|
||||
bot_opened_pr_title = discussion.title
|
||||
|
||||
self.assertTrue(bot_opened_pr)
|
||||
self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")
|
||||
|
||||
def test_safetensors_on_the_fly_specific_revision(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
initial_model = BertModel(config)
|
||||
|
||||
# Push a model on `main`
|
||||
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
|
||||
|
||||
# Push a model on a given revision
|
||||
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, revision="new-branch")
|
||||
|
||||
# Try to convert the model on that revision should raise
|
||||
with self.assertRaises(EnvironmentError):
|
||||
BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch")
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
class ModelPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user