This commit is contained in:
Patrick von Platen
2021-01-19 09:06:24 +01:00
committed by GitHub
parent 357fb1c5d8
commit 12c1b5b8f4
7 changed files with 13 additions and 13 deletions

View File

@@ -1073,7 +1073,7 @@ class ModelTesterMixin:
# some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present.
blacklist_non_batched_params = ["head_mask"]
blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]
for k in blacklist_non_batched_params:
inputs_dict.pop(k, None)