@@ -860,8 +860,9 @@ class ModelTesterMixin:
|
|||||||
model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)
|
model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)
|
||||||
|
|
||||||
model_eager.save_pretrained(tmpdir)
|
model_eager.save_pretrained(tmpdir)
|
||||||
with torch.device(torch_device):
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32)
|
tmpdir, torch_dtype=torch.float32, device_map=torch_device
|
||||||
|
)
|
||||||
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
|
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
|
||||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||||
_ = model(**inputs_dict, return_dict=False)
|
_ = model(**inputs_dict, return_dict=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user