Speed up TopKLogitsWarper and TopPLogitsWarper (pytorch) (#9557)
* make TopKLogitsWarper faster * make TopPLogitsWarper faster
This commit is contained in:
@@ -20,7 +20,6 @@ from typing import Callable, Iterable, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
|
||||
@@ -191,7 +190,7 @@ class TopPLogitsWarper(LogitsWarper):
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs > self.top_p
|
||||
@@ -204,7 +203,7 @@ class TopPLogitsWarper(LogitsWarper):
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
scores[indices_to_remove] = self.filter_value
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
|
||||
@@ -233,7 +232,7 @@ class TopKLogitsWarper(LogitsWarper):
|
||||
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
|
||||
scores[indices_to_remove] = self.filter_value
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user