fix test (#9669)
This commit is contained in:
committed by
GitHub
parent
357fb1c5d8
commit
12c1b5b8f4
@@ -1073,7 +1073,7 @@ class ModelTesterMixin:
|
||||
|
||||
# some params shouldn't be scattered by nn.DataParallel
|
||||
# so just remove them if they are present.
|
||||
blacklist_non_batched_params = ["head_mask"]
|
||||
blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]
|
||||
for k in blacklist_non_batched_params:
|
||||
inputs_dict.pop(k, None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user