Fix batching tests for new models (Mamba and SegGPT) (#29633)
* fix batchinng tests for new models * Update tests/models/seggpt/test_modeling_seggpt.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
31d01150ad
commit
5ac264d8a8
@@ -720,8 +720,8 @@ class ModelTesterMixin:
|
||||
batched_object.values(), single_row_object.values()
|
||||
):
|
||||
recursive_check(batched_object_value, single_row_object_value, model_name, key)
|
||||
# do not compare returned loss (0-dim tensor) or codebook ids (int)
|
||||
elif batched_object is None or isinstance(batched_object, int):
|
||||
# do not compare returned loss (0-dim tensor) / codebook ids (int) / caching objects
|
||||
elif batched_object is None or not isinstance(batched_object, torch.Tensor):
|
||||
return
|
||||
elif batched_object.dim() == 0:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user