Better typing for num_items_in_batch (#38728)
* fix * style * type checking ? * maybe this ? * fix * can't be an int anymore * fix
This commit is contained in:
@@ -944,7 +944,7 @@ class ModelTesterMixin:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
tmpdir, torch_dtype=torch.float32, device_map=torch_device
|
||||
)
|
||||
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
|
||||
inputs_dict["num_items_in_batch"] = torch.tensor(inputs_dict["input_ids"].shape[0])
|
||||
inputs_dict["labels"] = inputs_dict["input_ids"]
|
||||
_ = model(**inputs_dict, return_dict=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user