Make using safetensors files automated. (#27571)
* [WIP] Make using safetensors files automated. If `use_safetensors=True` is used, and it doesn't exist: - Don't crash just yet - Lookup for an open PR containing it. - If yes, use that instead - If not, touch the space to convert, wait for conversion to be finished and the PR to be opened - Use that new PR - Profit. * Remove the token. * [Auto Safetensors] Websocket -> SSE (#27656) * Websocket -> SSE * Support sharded + tests +cleanup a * env var * Apply suggestions from code review * Thanks Simon * Thanks Wauplin Co-authored-by: Wauplin <lucainp@gmail.com> * Cleanup * Update tests * Tests should pass * Apply to other tests * Extend extension * relax requirement on latest hfh * Revert * Correct private handling & debug statements * Skip gated repos as of now * Address review comments Co-authored-by: ArthurZucker <arthur.zucker@gmail.com> --------- Co-authored-by: Lysandre Debut <hi@lysand.re> Co-authored-by: Lysandre <lysandre@huggingface.co> Co-authored-by: Wauplin <lucainp@gmail.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr> Co-authored-by: ArthurZucker <arthur.zucker@gmail.com>
This commit is contained in:
@@ -49,6 +49,7 @@ from .pytorch_utils import ( # noqa: F401
|
|||||||
prune_layer,
|
prune_layer,
|
||||||
prune_linear_layer,
|
prune_linear_layer,
|
||||||
)
|
)
|
||||||
|
from .safetensors_conversion import auto_conversion
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ADAPTER_SAFE_WEIGHTS_NAME,
|
ADAPTER_SAFE_WEIGHTS_NAME,
|
||||||
ADAPTER_WEIGHTS_NAME,
|
ADAPTER_WEIGHTS_NAME,
|
||||||
@@ -3088,8 +3089,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
if resolved_archive_file is not None:
|
if resolved_archive_file is not None:
|
||||||
is_sharded = True
|
is_sharded = True
|
||||||
elif use_safetensors:
|
elif use_safetensors:
|
||||||
|
if revision == "main":
|
||||||
|
resolved_archive_file, revision, is_sharded = auto_conversion(
|
||||||
|
pretrained_model_name_or_path, **cached_file_kwargs
|
||||||
|
)
|
||||||
|
cached_file_kwargs["revision"] = revision
|
||||||
|
if resolved_archive_file is None:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
|
f"{pretrained_model_name_or_path} does not appear to have a file named"
|
||||||
|
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
|
||||||
|
"and thus cannot be loaded with `safetensors`. Please make sure that the model has "
|
||||||
|
"been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# This repo has no safetensors file of any kind, we switch to PyTorch.
|
# This repo has no safetensors file of any kind, we switch to PyTorch.
|
||||||
@@ -3144,7 +3154,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
|
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
|
||||||
# to the original exception.
|
# to the original exception.
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception as e:
|
||||||
# For any other exception, we throw a generic error.
|
# For any other exception, we throw a generic error.
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
|
||||||
@@ -3152,7 +3162,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
|
||||||
f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
|
f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
|
||||||
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
|
||||||
)
|
) from e
|
||||||
|
|
||||||
if is_local:
|
if is_local:
|
||||||
logger.info(f"loading weights file {archive_file}")
|
logger.info(f"loading weights file {archive_file}")
|
||||||
|
|||||||
107
src/transformers/safetensors_conversion.py
Normal file
107
src/transformers/safetensors_conversion.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from huggingface_hub import Discussion, HfApi, get_repo_discussions
|
||||||
|
|
||||||
|
from .utils import cached_file, logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]:
|
||||||
|
main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id
|
||||||
|
for discussion in get_repo_discussions(repo_id=model_id, token=token):
|
||||||
|
if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request:
|
||||||
|
commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token)
|
||||||
|
|
||||||
|
if main_commit == commits[1].commit_id:
|
||||||
|
return discussion
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def spawn_conversion(token: str, private: bool, model_id: str):
|
||||||
|
logger.info("Attempting to convert .bin model on the fly to safetensors.")
|
||||||
|
|
||||||
|
safetensors_convert_space_url = "https://safetensors-convert.hf.space"
|
||||||
|
sse_url = f"{safetensors_convert_space_url}/queue/join"
|
||||||
|
sse_data_url = f"{safetensors_convert_space_url}/queue/data"
|
||||||
|
|
||||||
|
# The `fn_index` is necessary to indicate to gradio that we will use the `run` method of the Space.
|
||||||
|
hash_data = {"fn_index": 1, "session_hash": str(uuid.uuid4())}
|
||||||
|
|
||||||
|
def start(_sse_connection, payload):
|
||||||
|
for line in _sse_connection.iter_lines():
|
||||||
|
line = line.decode()
|
||||||
|
if line.startswith("data:"):
|
||||||
|
resp = json.loads(line[5:])
|
||||||
|
logger.debug(f"Safetensors conversion status: {resp['msg']}")
|
||||||
|
if resp["msg"] == "queue_full":
|
||||||
|
raise ValueError("Queue is full! Please try again.")
|
||||||
|
elif resp["msg"] == "send_data":
|
||||||
|
event_id = resp["event_id"]
|
||||||
|
response = requests.post(
|
||||||
|
sse_data_url,
|
||||||
|
stream=True,
|
||||||
|
params=hash_data,
|
||||||
|
json={"event_id": event_id, **payload, **hash_data},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
elif resp["msg"] == "process_completed":
|
||||||
|
return
|
||||||
|
|
||||||
|
with requests.get(sse_url, stream=True, params=hash_data) as sse_connection:
|
||||||
|
data = {"data": [model_id, private, token]}
|
||||||
|
try:
|
||||||
|
logger.debug("Spawning safetensors automatic conversion.")
|
||||||
|
start(sse_connection, data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during conversion: {repr(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
|
||||||
|
private = api.model_info(model_id).private
|
||||||
|
|
||||||
|
logger.info("Attempting to create safetensors variant")
|
||||||
|
pr_title = "Adding `safetensors` variant of this model"
|
||||||
|
token = kwargs.get("token")
|
||||||
|
|
||||||
|
# This looks into the current repo's open PRs to see if a PR for safetensors was already open. If so, it
|
||||||
|
# returns it. It checks that the PR was opened by the bot and not by another user so as to prevent
|
||||||
|
# security breaches.
|
||||||
|
pr = previous_pr(api, model_id, pr_title, token=token)
|
||||||
|
|
||||||
|
if pr is None or (not private and pr.author != "SFConvertBot"):
|
||||||
|
spawn_conversion(token, private, model_id)
|
||||||
|
pr = previous_pr(api, model_id, pr_title, token=token)
|
||||||
|
else:
|
||||||
|
logger.info("Safetensors PR exists")
|
||||||
|
|
||||||
|
sha = f"refs/pr/{pr.num}"
|
||||||
|
|
||||||
|
return sha
|
||||||
|
|
||||||
|
|
||||||
|
def auto_conversion(pretrained_model_name_or_path: str, **cached_file_kwargs):
|
||||||
|
api = HfApi(token=cached_file_kwargs.get("token"))
|
||||||
|
sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
|
||||||
|
|
||||||
|
if sha is None:
|
||||||
|
return None, None
|
||||||
|
cached_file_kwargs["revision"] = sha
|
||||||
|
del cached_file_kwargs["_commit_hash"]
|
||||||
|
|
||||||
|
# This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
|
||||||
|
# description.
|
||||||
|
sharded = api.file_exists(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
"model.safetensors.index.json",
|
||||||
|
revision=sha,
|
||||||
|
token=cached_file_kwargs.get("token"),
|
||||||
|
)
|
||||||
|
filename = "model.safetensors.index.json" if sharded else "model.safetensors"
|
||||||
|
|
||||||
|
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
|
||||||
|
return resolved_archive_file, sha, sharded
|
||||||
@@ -21,9 +21,11 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import unittest.mock as mock
|
import unittest.mock as mock
|
||||||
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, delete_repo
|
import requests
|
||||||
|
from huggingface_hub import HfApi, HfFolder, delete_repo
|
||||||
from huggingface_hub.file_download import http_get
|
from huggingface_hub.file_download import http_get
|
||||||
from pytest import mark
|
from pytest import mark
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
@@ -829,15 +831,9 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
|
|
||||||
@require_safetensors
|
@require_safetensors
|
||||||
def test_use_safetensors(self):
|
def test_use_safetensors(self):
|
||||||
# test nice error message if no safetensor files available
|
# Should not raise anymore
|
||||||
with self.assertRaises(OSError) as env_error:
|
|
||||||
AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)
|
AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)
|
||||||
|
|
||||||
self.assertTrue(
|
|
||||||
"model.safetensors or model.safetensors.index.json and thus cannot be loaded with `safetensors`"
|
|
||||||
in str(env_error.exception)
|
|
||||||
)
|
|
||||||
|
|
||||||
# test that error if only safetensors is available
|
# test that error if only safetensors is available
|
||||||
with self.assertRaises(OSError) as env_error:
|
with self.assertRaises(OSError) as env_error:
|
||||||
BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)
|
BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)
|
||||||
@@ -1171,6 +1167,202 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
self.assertTrue(torch.equal(p1, p2))
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class ModelOnTheFlyConversionTester(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.user = "huggingface-hub-ci"
|
||||||
|
cls.token = os.getenv("HUGGINGFACE_PRODUCTION_USER_TOKEN", None)
|
||||||
|
|
||||||
|
if cls.token is None:
|
||||||
|
raise ValueError("Cannot run tests as secret isn't setup.")
|
||||||
|
|
||||||
|
cls.api = HfApi(token=cls.token)
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.repo_name = f"{self.user}/test-model-on-the-fly-{uuid.uuid4()}"
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
self.api.delete_repo(self.repo_name)
|
||||||
|
|
||||||
|
def test_safetensors_on_the_fly_conversion(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
initial_model = BertModel(config)
|
||||||
|
|
||||||
|
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
|
||||||
|
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True)
|
||||||
|
|
||||||
|
with self.subTest("Initial and converted models are equal"):
|
||||||
|
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
with self.subTest("PR was open with the safetensors account"):
|
||||||
|
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||||
|
discussion = next(discussions)
|
||||||
|
self.assertEqual(discussion.author, "SFconvertbot")
|
||||||
|
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||||
|
|
||||||
|
def test_safetensors_on_the_fly_conversion_private(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
initial_model = BertModel(config)
|
||||||
|
|
||||||
|
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, private=True)
|
||||||
|
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||||
|
|
||||||
|
with self.subTest("Initial and converted models are equal"):
|
||||||
|
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
with self.subTest("PR was open with the safetensors account"):
|
||||||
|
discussions = self.api.get_repo_discussions(self.repo_name, token=self.token)
|
||||||
|
discussion = next(discussions)
|
||||||
|
self.assertEqual(discussion.author, self.user)
|
||||||
|
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||||
|
|
||||||
|
def test_safetensors_on_the_fly_conversion_gated(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
initial_model = BertModel(config)
|
||||||
|
|
||||||
|
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
|
||||||
|
headers = {"Authorization": f"Bearer {self.token}"}
|
||||||
|
requests.put(
|
||||||
|
f"https://huggingface.co/api/models/{self.repo_name}/settings", json={"gated": "auto"}, headers=headers
|
||||||
|
)
|
||||||
|
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||||
|
|
||||||
|
with self.subTest("Initial and converted models are equal"):
|
||||||
|
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
with self.subTest("PR was open with the safetensors account"):
|
||||||
|
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||||
|
discussion = next(discussions)
|
||||||
|
self.assertEqual(discussion.author, "SFconvertbot")
|
||||||
|
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||||
|
|
||||||
|
def test_safetensors_on_the_fly_sharded_conversion(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
initial_model = BertModel(config)
|
||||||
|
|
||||||
|
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, max_shard_size="200kb")
|
||||||
|
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True)
|
||||||
|
|
||||||
|
with self.subTest("Initial and converted models are equal"):
|
||||||
|
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
with self.subTest("PR was open with the safetensors account"):
|
||||||
|
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||||
|
discussion = next(discussions)
|
||||||
|
self.assertEqual(discussion.author, "SFconvertbot")
|
||||||
|
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||||
|
|
||||||
|
def test_safetensors_on_the_fly_sharded_conversion_private(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
initial_model = BertModel(config)
|
||||||
|
|
||||||
|
initial_model.push_to_hub(
|
||||||
|
self.repo_name, token=self.token, safe_serialization=False, max_shard_size="200kb", private=True
|
||||||
|
)
|
||||||
|
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||||
|
|
||||||
|
with self.subTest("Initial and converted models are equal"):
|
||||||
|
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
with self.subTest("PR was open with the safetensors account"):
|
||||||
|
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||||
|
discussion = next(discussions)
|
||||||
|
self.assertEqual(discussion.author, self.user)
|
||||||
|
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||||
|
|
||||||
|
def test_safetensors_on_the_fly_sharded_conversion_gated(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
initial_model = BertModel(config)
|
||||||
|
|
||||||
|
initial_model.push_to_hub(self.repo_name, token=self.token, max_shard_size="200kb", safe_serialization=False)
|
||||||
|
headers = {"Authorization": f"Bearer {self.token}"}
|
||||||
|
requests.put(
|
||||||
|
f"https://huggingface.co/api/models/{self.repo_name}/settings", json={"gated": "auto"}, headers=headers
|
||||||
|
)
|
||||||
|
converted_model = BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||||
|
|
||||||
|
with self.subTest("Initial and converted models are equal"):
|
||||||
|
for p1, p2 in zip(initial_model.parameters(), converted_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
with self.subTest("PR was open with the safetensors account"):
|
||||||
|
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||||
|
discussion = next(discussions)
|
||||||
|
self.assertEqual(discussion.author, "SFconvertbot")
|
||||||
|
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||||
|
|
||||||
|
@unittest.skip("Edge case, should work once the Space is updated`")
|
||||||
|
def test_safetensors_on_the_fly_wrong_user_opened_pr(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
initial_model = BertModel(config)
|
||||||
|
|
||||||
|
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, private=True)
|
||||||
|
BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||||
|
|
||||||
|
# This should have opened a PR with the user's account
|
||||||
|
with self.subTest("PR was open with the safetensors account"):
|
||||||
|
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||||
|
discussion = next(discussions)
|
||||||
|
self.assertEqual(discussion.author, self.user)
|
||||||
|
self.assertEqual(discussion.title, "Adding `safetensors` variant of this model")
|
||||||
|
|
||||||
|
# We now switch the repo visibility to public
|
||||||
|
self.api.update_repo_visibility(self.repo_name, private=False)
|
||||||
|
|
||||||
|
# We once again call from_pretrained, which should call the bot to open a PR
|
||||||
|
BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token)
|
||||||
|
|
||||||
|
with self.subTest("PR was open with the safetensors account"):
|
||||||
|
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||||
|
|
||||||
|
bot_opened_pr = None
|
||||||
|
bot_opened_pr_title = None
|
||||||
|
|
||||||
|
for discussion in discussions:
|
||||||
|
if discussion.author == "SFconvertBot":
|
||||||
|
bot_opened_pr = True
|
||||||
|
bot_opened_pr_title = discussion.title
|
||||||
|
|
||||||
|
self.assertTrue(bot_opened_pr)
|
||||||
|
self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")
|
||||||
|
|
||||||
|
def test_safetensors_on_the_fly_specific_revision(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
initial_model = BertModel(config)
|
||||||
|
|
||||||
|
# Push a model on `main`
|
||||||
|
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
|
||||||
|
|
||||||
|
# Push a model on a given revision
|
||||||
|
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False, revision="new-branch")
|
||||||
|
|
||||||
|
# Try to convert the model on that revision should raise
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch")
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
class ModelPushToHubTester(unittest.TestCase):
|
class ModelPushToHubTester(unittest.TestCase):
|
||||||
|
|||||||
@@ -330,6 +330,7 @@ IGNORE_SUBMODULES = [
|
|||||||
"modeling_flax_pytorch_utils",
|
"modeling_flax_pytorch_utils",
|
||||||
"models.esm.openfold_utils",
|
"models.esm.openfold_utils",
|
||||||
"modeling_attn_mask_utils",
|
"modeling_attn_mask_utils",
|
||||||
|
"safetensors_conversion",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user