Fix failing conversion (#34010)

* Fix

* Tests

* Typo

* Typo
This commit is contained in:
Lysandre Debut
2024-10-11 14:59:23 +02:00
committed by GitHub
parent 9dca0c9116
commit 409dd2d19c
2 changed files with 31 additions and 38 deletions

View File

@@ -1,5 +1,3 @@
import json
import uuid
from typing import Optional from typing import Optional
import requests import requests
@@ -26,37 +24,33 @@ def spawn_conversion(token: str, private: bool, model_id: str):
logger.info("Attempting to convert .bin model on the fly to safetensors.") logger.info("Attempting to convert .bin model on the fly to safetensors.")
safetensors_convert_space_url = "https://safetensors-convert.hf.space" safetensors_convert_space_url = "https://safetensors-convert.hf.space"
sse_url = f"{safetensors_convert_space_url}/queue/join" sse_url = f"{safetensors_convert_space_url}/call/run"
sse_data_url = f"{safetensors_convert_space_url}/queue/data"
# The `fn_index` is necessary to indicate to gradio that we will use the `run` method of the Space. def start(_sse_connection):
hash_data = {"fn_index": 1, "session_hash": str(uuid.uuid4())}
def start(_sse_connection, payload):
for line in _sse_connection.iter_lines(): for line in _sse_connection.iter_lines():
line = line.decode() line = line.decode()
if line.startswith("data:"): if line.startswith("event:"):
resp = json.loads(line[5:]) status = line[7:]
logger.debug(f"Safetensors conversion status: {resp['msg']}") logger.debug(f"Safetensors conversion status: {status}")
if resp["msg"] == "queue_full":
raise ValueError("Queue is full! Please try again.")
elif resp["msg"] == "send_data":
event_id = resp["event_id"]
response = requests.post(
sse_data_url,
stream=True,
params=hash_data,
json={"event_id": event_id, **payload, **hash_data},
)
response.raise_for_status()
elif resp["msg"] == "process_completed":
return
with requests.get(sse_url, stream=True, params=hash_data) as sse_connection: if status == "complete":
data = {"data": [model_id, private, token]} return
elif status == "heartbeat":
logger.debug("Heartbeat")
else:
logger.debug(f"Unknown status {status}")
else:
logger.debug(line)
data = {"data": [model_id, private, token]}
result = requests.post(sse_url, stream=True, json=data).json()
event_id = result["event_id"]
with requests.get(f"{sse_url}/{event_id}", stream=True) as sse_connection:
try: try:
logger.debug("Spawning safetensors automatic conversion.") logger.debug("Spawning safetensors automatic conversion.")
start(sse_connection, data) start(sse_connection)
except Exception as e: except Exception as e:
logger.warning(f"Error during conversion: {repr(e)}") logger.warning(f"Error during conversion: {repr(e)}")
@@ -86,7 +80,7 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs): def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
try: try:
api = HfApi(token=cached_file_kwargs.get("token"), headers=http_user_agent()) api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()})
sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs) sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
if sha is None: if sha is None:

View File

@@ -2009,19 +2009,18 @@ class ModelOnTheFlyConversionTester(unittest.TestCase):
if thread.name == "Thread-autoconversion": if thread.name == "Thread-autoconversion":
thread.join(timeout=10) thread.join(timeout=10)
with self.subTest("PR was open with the safetensors account"): discussions = self.api.get_repo_discussions(self.repo_name)
discussions = self.api.get_repo_discussions(self.repo_name)
bot_opened_pr = None bot_opened_pr = None
bot_opened_pr_title = None bot_opened_pr_title = None
for discussion in discussions: for discussion in discussions:
if discussion.author == "SFconvertbot": if discussion.author == "SFconvertbot":
bot_opened_pr = True bot_opened_pr = True
bot_opened_pr_title = discussion.title bot_opened_pr_title = discussion.title
self.assertTrue(bot_opened_pr) self.assertTrue(bot_opened_pr)
self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model") self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")
@mock.patch("transformers.safetensors_conversion.spawn_conversion") @mock.patch("transformers.safetensors_conversion.spawn_conversion")
def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock): def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock):