@@ -860,8 +860,9 @@ class ModelTesterMixin:
|
||||
model_eager = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float32)
|
||||
|
||||
model_eager.save_pretrained(tmpdir)
|
||||
with torch.device(torch_device):
|
||||
model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
tmpdir, torch_dtype=torch.float32, device_map=torch_device
|
||||
)
|
||||
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
_ = model(**inputs_dict, return_dict=False)
|
||||
|
||||
Reference in New Issue
Block a user