@@ -1,5 +1,3 @@
|
|||||||
import json
|
|
||||||
import uuid
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -26,37 +24,33 @@ def spawn_conversion(token: str, private: bool, model_id: str):
|
|||||||
logger.info("Attempting to convert .bin model on the fly to safetensors.")
|
logger.info("Attempting to convert .bin model on the fly to safetensors.")
|
||||||
|
|
||||||
safetensors_convert_space_url = "https://safetensors-convert.hf.space"
|
safetensors_convert_space_url = "https://safetensors-convert.hf.space"
|
||||||
sse_url = f"{safetensors_convert_space_url}/queue/join"
|
sse_url = f"{safetensors_convert_space_url}/call/run"
|
||||||
sse_data_url = f"{safetensors_convert_space_url}/queue/data"
|
|
||||||
|
|
||||||
# The `fn_index` is necessary to indicate to gradio that we will use the `run` method of the Space.
|
def start(_sse_connection):
|
||||||
hash_data = {"fn_index": 1, "session_hash": str(uuid.uuid4())}
|
|
||||||
|
|
||||||
def start(_sse_connection, payload):
|
|
||||||
for line in _sse_connection.iter_lines():
|
for line in _sse_connection.iter_lines():
|
||||||
line = line.decode()
|
line = line.decode()
|
||||||
if line.startswith("data:"):
|
if line.startswith("event:"):
|
||||||
resp = json.loads(line[5:])
|
status = line[7:]
|
||||||
logger.debug(f"Safetensors conversion status: {resp['msg']}")
|
logger.debug(f"Safetensors conversion status: {status}")
|
||||||
if resp["msg"] == "queue_full":
|
|
||||||
raise ValueError("Queue is full! Please try again.")
|
if status == "complete":
|
||||||
elif resp["msg"] == "send_data":
|
return
|
||||||
event_id = resp["event_id"]
|
elif status == "heartbeat":
|
||||||
response = requests.post(
|
logger.debug("Heartbeat")
|
||||||
sse_data_url,
|
else:
|
||||||
stream=True,
|
logger.debug(f"Unknown status {status}")
|
||||||
params=hash_data,
|
else:
|
||||||
json={"event_id": event_id, **payload, **hash_data},
|
logger.debug(line)
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
elif resp["msg"] == "process_completed":
|
|
||||||
return
|
|
||||||
|
|
||||||
with requests.get(sse_url, stream=True, params=hash_data) as sse_connection:
|
|
||||||
data = {"data": [model_id, private, token]}
|
data = {"data": [model_id, private, token]}
|
||||||
|
|
||||||
|
result = requests.post(sse_url, stream=True, json=data).json()
|
||||||
|
event_id = result["event_id"]
|
||||||
|
|
||||||
|
with requests.get(f"{sse_url}/{event_id}", stream=True) as sse_connection:
|
||||||
try:
|
try:
|
||||||
logger.debug("Spawning safetensors automatic conversion.")
|
logger.debug("Spawning safetensors automatic conversion.")
|
||||||
start(sse_connection, data)
|
start(sse_connection)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error during conversion: {repr(e)}")
|
logger.warning(f"Error during conversion: {repr(e)}")
|
||||||
|
|
||||||
@@ -86,7 +80,7 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
|
|||||||
|
|
||||||
def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
|
def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
|
||||||
try:
|
try:
|
||||||
api = HfApi(token=cached_file_kwargs.get("token"), headers=http_user_agent())
|
api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()})
|
||||||
sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
|
sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
|
||||||
|
|
||||||
if sha is None:
|
if sha is None:
|
||||||
|
|||||||
@@ -2009,7 +2009,6 @@ class ModelOnTheFlyConversionTester(unittest.TestCase):
|
|||||||
if thread.name == "Thread-autoconversion":
|
if thread.name == "Thread-autoconversion":
|
||||||
thread.join(timeout=10)
|
thread.join(timeout=10)
|
||||||
|
|
||||||
with self.subTest("PR was open with the safetensors account"):
|
|
||||||
discussions = self.api.get_repo_discussions(self.repo_name)
|
discussions = self.api.get_repo_discussions(self.repo_name)
|
||||||
|
|
||||||
bot_opened_pr = None
|
bot_opened_pr = None
|
||||||
|
|||||||
Reference in New Issue
Block a user