Experimental loading of MLX files (#29511)

* Experimental loading of MLX files

* Update exception message

* Add test

* Style

* Use model from hf-internal-testing
This commit is contained in:
Pedro Cuenca
2024-03-11 19:42:06 +01:00
committed by GitHub
parent 73a27345d4
commit b382a09e28
2 changed files with 24 additions and 1 deletions

View File

@@ -1256,6 +1256,26 @@ class ModelUtilsTest(TestCasePlus):
self.assertEqual(len(logs.output), 1)
self.assertIn("Your generation config was originally created from the model config", logs.output[0])
@require_safetensors
def test_model_from_pretrained_from_mlx(self):
    """Round-trip the `tiny-random-mistral-mlx` test checkpoint through `save_pretrained`.

    The checkpoint is loaded with `AutoModelForCausalLM.from_pretrained`, saved again
    with safe serialization, and reloaded from disk. The test verifies that:

    - the re-saved `model.safetensors` file carries `"format": "pt"` in its metadata;
    - the reloaded model yields the same logits as the originally loaded one
      (compared with `torch.allclose`) for a random batch of token ids.
    """
    from safetensors import safe_open

    original = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-mistral-mlx")
    self.assertIsNotNone(original)

    with tempfile.TemporaryDirectory() as tmp_dir:
        original.save_pretrained(tmp_dir, safe_serialization=True)

        # Inspect the header of the freshly written weights file.
        weights_path = os.path.join(tmp_dir, "model.safetensors")
        with safe_open(weights_path, framework="pt") as handle:
            saved_format = handle.metadata().get("format")
        self.assertEqual(saved_format, "pt")

        # Reload while the temporary directory still exists.
        reloaded = AutoModelForCausalLM.from_pretrained(tmp_dir)

    input_ids = torch.randint(100, 1000, (1, 10))
    with torch.no_grad():
        expected_logits = original(input_ids)["logits"]
        reloaded_logits = reloaded(input_ids)["logits"]
    self.assertTrue(torch.allclose(reloaded_logits, expected_logits))
@slow
@require_torch