* Update test_modeling_common.py

* style
This commit is contained in:
Cyril Vallez
2025-04-03 10:24:34 +02:00
committed by GitHub
parent 12048990a9
commit 6ce238fe7a

View File

@@ -860,11 +860,12 @@ class ModelTesterMixin:
model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32) model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)
model_eager.save_pretrained(tmpdir) model_eager.save_pretrained(tmpdir)
with torch.device(torch_device): model = AutoModelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32) tmpdir, torch_dtype=torch.float32, device_map=torch_device
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0] )
inputs_dict["labels"] = inputs_dict["input_ids"] inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
_ = model(**inputs_dict, return_dict=False) inputs_dict["labels"] = inputs_dict["input_ids"]
_ = model(**inputs_dict, return_dict=False)
def test_training_gradient_checkpointing(self): def test_training_gradient_checkpointing(self):
# Scenario - 1 default behaviour # Scenario - 1 default behaviour