From e1844d9a45c6232ba7fac77c6c7b3f5326e72929 Mon Sep 17 00:00:00 2001 From: James Noeckel Date: Wed, 25 Dec 2019 01:34:02 -0800 Subject: [PATCH] use positional arguments due to inconsistent API --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 061e1ba57d..054de04dcf 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -958,7 +958,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf") sorted_indices_to_remove[..., 0] = 0 # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) logits[indices_to_remove] = filter_value return logits