revert erroneous fix

This commit is contained in:
James Noeckel
2019-12-24 22:26:09 -08:00
parent 81db12c3ba
commit 9fb7addd4d

View File

@@ -958,9 +958,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")
sorted_indices_to_remove[..., 0] = 0 sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing # scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
dim=1, index=sorted_indices, source=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value logits[indices_to_remove] = filter_value
return logits return logits