[Performance improvement] "Bad tokens ids" optimization (#6064)
* Optimized banned token masking * Avoid duplicate EOS masking if in bad_words_id * Updated mask generation to handle empty banned token list * Addition of unit tests for the updated bad_words_ids masking * Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test * Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test (timeout does not work on Windows) * Moving Marian import to the test context to allow TF only environments to run * Moving imports to torch_available test * Updated operations device and test * Updated operations device and test * Added docstring and comment for in-place scores modification * Moving test to own test_generation_utils, use of lighter models for testing * removed unneded imports in test_modeling_common * revert formatting change for ModelTesterMixin * Updated caching, simplified eos token id test, removed unnecessary @require_torch * formatting compliance
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Iterable, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@@ -89,11 +89,12 @@ class GenerationMixin:
|
||||
scores[i, banned_tokens] = -float("inf")
|
||||
|
||||
if bad_words_ids is not None:
|
||||
# Exclude EOS token (already processed)
|
||||
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
|
||||
# calculate a list of banned tokens according to bad words
|
||||
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
|
||||
|
||||
for i, banned_tokens in enumerate(banned_tokens):
|
||||
scores[i, banned_tokens] = -float("inf")
|
||||
banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
|
||||
# Modify the scores in place by setting the banned tokens logits to `-inf`
|
||||
set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
|
||||
|
||||
return scores
|
||||
|
||||
@@ -893,7 +894,7 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
|
||||
bad_words_ids
|
||||
)
|
||||
|
||||
if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
|
||||
if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
|
||||
# if tokens do not match continue
|
||||
continue
|
||||
|
||||
@@ -904,6 +905,30 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
|
||||
return banned_tokens
|
||||
|
||||
|
||||
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
|
||||
""" Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
|
||||
a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
|
||||
Args:
|
||||
scores: logits distribution of shape (batch size, vocabulary size)
|
||||
banned_tokens: list of list of tokens to ban of length (batch_size)
|
||||
"""
|
||||
banned_mask_list = []
|
||||
for idx, batch_banned_tokens in enumerate(banned_tokens):
|
||||
for token in batch_banned_tokens:
|
||||
banned_mask_list.append([idx, token])
|
||||
if not banned_mask_list:
|
||||
return
|
||||
banned_mask = torch.LongTensor(banned_mask_list)
|
||||
indices = torch.ones(len(banned_mask))
|
||||
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
|
||||
# [ 0 1 1 ]
|
||||
# [ 0 0 0 ]
|
||||
# [ 1 0 0 ]
|
||||
|
||||
banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
|
||||
scores.masked_fill_(banned_mask, -float("inf"))
|
||||
|
||||
|
||||
def top_k_top_p_filtering(
|
||||
logits: Tensor,
|
||||
top_k: int = 0,
|
||||
|
||||
Reference in New Issue
Block a user