Better typing for num_items_in_batch (#38728)

* fix

* style

* type checking ?

* maybe this ?

* fix

* can't be an int anymore

* fix
This commit is contained in:
Marc Sun
2025-06-11 16:26:41 +02:00
committed by GitHub
parent 84710a4291
commit 11ad9be153
5 changed files with 47 additions and 18 deletions

View File

@@ -944,7 +944,7 @@ class ModelTesterMixin:
model = AutoModelForCausalLM.from_pretrained(
tmpdir, torch_dtype=torch.float32, device_map=torch_device
)
inputs_dict["num_items_in_batch"] = inputs_dict["input_ids"].shape[0]
inputs_dict["num_items_in_batch"] = torch.tensor(inputs_dict["input_ids"].shape[0])
inputs_dict["labels"] = inputs_dict["input_ids"]
_ = model(**inputs_dict, return_dict=False)