From 0c9f01a8e59ee145db977f28bef5a97063424a8d Mon Sep 17 00:00:00 2001 From: LSinev Date: Wed, 13 Jan 2021 15:47:47 +0300 Subject: [PATCH] Speed up TopKLogitsWarper and TopPLogitsWarper (pytorch) (#9557) * make TopKLogitsWarper faster * make TopPLogitsWarper faster --- src/transformers/generation_logits_process.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py index 465cdc4f6d..166cd5aa8c 100644 --- a/src/transformers/generation_logits_process.py +++ b/src/transformers/generation_logits_process.py @@ -20,7 +20,6 @@ from typing import Callable, Iterable, List import numpy as np import torch -from torch.nn import functional as F from .file_utils import add_start_docstrings @@ -191,7 +190,7 @@ class TopPLogitsWarper(LogitsWarper): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: sorted_logits, sorted_indices = torch.sort(scores, descending=True) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) sorted_indices_to_remove = cumulative_probs > self.top_p @@ -204,7 +203,7 @@ class TopPLogitsWarper(LogitsWarper): # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - scores[indices_to_remove] = self.filter_value + scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores @@ -233,7 +232,7 @@ class TopKLogitsWarper(LogitsWarper): top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] - scores[indices_to_remove] = self.filter_value + scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores