From 398bb03f9865d3e30cd2eda4b8e0a9ca3402c7a7 Mon Sep 17 00:00:00 2001 From: James Noeckel Date: Sun, 22 Dec 2019 23:18:41 -0800 Subject: [PATCH] fix out-of-place call to scatter, whose named argument name is source, not src --- src/transformers/modeling_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0ec6972841..e9ecfd60dc 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -958,7 +958,9 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf") sorted_indices_to_remove[..., 0] = 0 # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) + indices_to_remove = sorted_indices_to_remove.scatter( + dim=1, index=sorted_indices, source=sorted_indices_to_remove + ) logits[indices_to_remove] = filter_value return logits