Speed up TopKLogitsWarper and TopPLogitsWarper (pytorch) (#9557)

* make TopKLogitsWarper faster

* make TopPLogitsWarper faster
This commit is contained in:
LSinev
2021-01-13 15:47:47 +03:00
committed by GitHub
parent 27d0e01d75
commit 0c9f01a8e5

View File

@@ -20,7 +20,6 @@ from typing import Callable, Iterable, List
import numpy as np
import torch
from torch.nn import functional as F
from .file_utils import add_start_docstrings
@@ -191,7 +190,7 @@ class TopPLogitsWarper(LogitsWarper):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs > self.top_p
@@ -204,7 +203,7 @@ class TopPLogitsWarper(LogitsWarper):
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores[indices_to_remove] = self.filter_value
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
@@ -233,7 +232,7 @@ class TopKLogitsWarper(LogitsWarper):
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
scores[indices_to_remove] = self.filter_value
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores