From b382a09e28c7e59129246ccdf4b00f2cac4547a4 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 11 Mar 2024 19:42:06 +0100 Subject: [PATCH] Experimental loading of MLX files (#29511) * Experimental loading of MLX files * Update exception message * Add test * Style * Use model from hf-internal-testing --- src/transformers/modeling_utils.py | 5 ++++- tests/test_modeling_utils.py | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ca5ee26e27..9f4e0a136a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3297,9 +3297,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix elif metadata.get("format") == "flax": from_flax = True logger.info("A Flax safetensors file is being loaded in a PyTorch model.") + elif metadata.get("format") == "mlx": + # This is a mlx file, we assume weights are compatible with pt + pass else: raise ValueError( - f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}" + f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}" ) from_pt = not (from_tf | from_flax) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 87b933425f..1b2351be93 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1256,6 +1256,26 @@ class ModelUtilsTest(TestCasePlus): self.assertEqual(len(logs.output), 1) self.assertIn("Your generation config was originally created from the model config", logs.output[0]) + @require_safetensors + def test_model_from_pretrained_from_mlx(self): + from safetensors import safe_open + + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-mistral-mlx") + self.assertIsNotNone(model) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir, safe_serialization=True) + with safe_open(os.path.join(tmp_dir, "model.safetensors"), framework="pt") as f: + metadata = f.metadata() + self.assertEqual(metadata.get("format"), "pt") + new_model = AutoModelForCausalLM.from_pretrained(tmp_dir) + + input_ids = torch.randint(100, 1000, (1, 10)) + with torch.no_grad(): + outputs = model(input_ids) + outputs_from_saved = new_model(input_ids) + self.assertTrue(torch.allclose(outputs_from_saved["logits"], outputs["logits"])) + @slow @require_torch