From 6ce238fe7a930fc03f39cdf72fc9ce9807c83e55 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 3 Apr 2025 10:24:34 +0200 Subject: [PATCH] Fix test (#37213) * Update test_modeling_common.py * style --- tests/test_modeling_common.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index ce8921b333..42aea7b67d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -860,11 +860,12 @@ class ModelTesterMixin: model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32) model_eager.save_pretrained(tmpdir) - with torch.device(torch_device): - model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32) - inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0] - inputs_dict["labels"] = inputs_dict["input_ids"] - _ = model(**inputs_dict, return_dict=False) + model = AutoModelForCausalLM.from_pretrained( + tmpdir, torch_dtype=torch.float32, device_map=torch_device + ) + inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0] + inputs_dict["labels"] = inputs_dict["input_ids"] + _ = model(**inputs_dict, return_dict=False) def test_training_gradient_checkpointing(self): # Scenario - 1 default behaviour