[Performance improvement] "Bad tokens ids" optimization (#6064)
* Optimized banned token masking * Avoid duplicate EOS masking if in bad_words_id * Updated mask generation to handle empty banned token list * Addition of unit tests for the updated bad_words_ids masking * Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test * Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test (timeout does not work on Windows) * Moving Marian import to the test context to allow TF only environments to run * Moving imports to torch_available test * Updated operations device and test * Updated operations device and test * Added docstring and comment for in-place scores modification * Moving test to own test_generation_utils, use of lighter models for testing * removed unneded imports in test_modeling_common * revert formatting change for ModelTesterMixin * Updated caching, simplified eos token id test, removed unnecessary @require_torch * formatting compliance
This commit is contained in:
90
src/transformers/data/test_generation_utils.py
Normal file
90
src/transformers/data/test_generation_utils.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import timeout_decorator
|
||||||
|
|
||||||
|
from transformers import is_torch_available
|
||||||
|
from transformers.file_utils import cached_property
|
||||||
|
from transformers.testing_utils import require_torch
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
MarianConfig,
|
||||||
|
MarianMTModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class GenerationUtilsTest(unittest.TestCase):
|
||||||
|
@cached_property
|
||||||
|
def config(self):
|
||||||
|
config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de")
|
||||||
|
return config
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def model(self):
|
||||||
|
return MarianMTModel(self.config)
|
||||||
|
|
||||||
|
def test_postprocess_next_token_scores(self):
|
||||||
|
config = self.config
|
||||||
|
model = self.model
|
||||||
|
# Initialize an input id tensor with batch size 8 and sequence length 12
|
||||||
|
input_ids = torch.arange(0, 96, 1).view((8, 12))
|
||||||
|
eos = config.eos_token_id
|
||||||
|
bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []]
|
||||||
|
masked_scores = [
|
||||||
|
[(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)],
|
||||||
|
[(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)],
|
||||||
|
[(0, eos), (1, eos), (2, eos), (3, eos), (4, eos), (5, eos), (6, eos), (7, eos)],
|
||||||
|
[],
|
||||||
|
]
|
||||||
|
|
||||||
|
for test_case_index, bad_words_ids in enumerate(bad_words_ids_test_cases):
|
||||||
|
# Initialize a scores tensor with batch size 8 and vocabulary size 300
|
||||||
|
scores = torch.rand((8, 300))
|
||||||
|
output = model.postprocess_next_token_scores(
|
||||||
|
scores,
|
||||||
|
input_ids,
|
||||||
|
0,
|
||||||
|
bad_words_ids,
|
||||||
|
13,
|
||||||
|
15,
|
||||||
|
config.max_length,
|
||||||
|
config.eos_token_id,
|
||||||
|
config.repetition_penalty,
|
||||||
|
32,
|
||||||
|
5,
|
||||||
|
)
|
||||||
|
for masked_score in masked_scores[test_case_index]:
|
||||||
|
self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf"))
|
||||||
|
|
||||||
|
@timeout_decorator.timeout(10)
|
||||||
|
def test_postprocess_next_token_scores_large_bad_words_list(self):
|
||||||
|
|
||||||
|
config = self.config
|
||||||
|
model = self.model
|
||||||
|
# Initialize an input id tensor with batch size 8 and sequence length 12
|
||||||
|
input_ids = torch.arange(0, 96, 1).view((8, 12))
|
||||||
|
|
||||||
|
bad_words_ids = []
|
||||||
|
for _ in range(100):
|
||||||
|
length_bad_word = random.randint(1, 4)
|
||||||
|
bad_words_ids.append(random.sample(range(1, 300), length_bad_word))
|
||||||
|
|
||||||
|
scores = torch.rand((8, 300))
|
||||||
|
_ = model.postprocess_next_token_scores(
|
||||||
|
scores,
|
||||||
|
input_ids,
|
||||||
|
0,
|
||||||
|
bad_words_ids,
|
||||||
|
13,
|
||||||
|
15,
|
||||||
|
config.max_length,
|
||||||
|
config.eos_token_id,
|
||||||
|
config.repetition_penalty,
|
||||||
|
32,
|
||||||
|
5,
|
||||||
|
)
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Iterable, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@@ -89,11 +89,12 @@ class GenerationMixin:
|
|||||||
scores[i, banned_tokens] = -float("inf")
|
scores[i, banned_tokens] = -float("inf")
|
||||||
|
|
||||||
if bad_words_ids is not None:
|
if bad_words_ids is not None:
|
||||||
|
# Exclude EOS token (already processed)
|
||||||
|
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
|
||||||
# calculate a list of banned tokens according to bad words
|
# calculate a list of banned tokens according to bad words
|
||||||
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
|
banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
|
||||||
|
# Modify the scores in place by setting the banned tokens logits to `-inf`
|
||||||
for i, banned_tokens in enumerate(banned_tokens):
|
set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
|
||||||
scores[i, banned_tokens] = -float("inf")
|
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@@ -893,7 +894,7 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
|
|||||||
bad_words_ids
|
bad_words_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
|
if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
|
||||||
# if tokens do not match continue
|
# if tokens do not match continue
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -904,6 +905,30 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
|
|||||||
return banned_tokens
|
return banned_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
|
||||||
|
""" Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
|
||||||
|
a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
|
||||||
|
Args:
|
||||||
|
scores: logits distribution of shape (batch size, vocabulary size)
|
||||||
|
banned_tokens: list of list of tokens to ban of length (batch_size)
|
||||||
|
"""
|
||||||
|
banned_mask_list = []
|
||||||
|
for idx, batch_banned_tokens in enumerate(banned_tokens):
|
||||||
|
for token in batch_banned_tokens:
|
||||||
|
banned_mask_list.append([idx, token])
|
||||||
|
if not banned_mask_list:
|
||||||
|
return
|
||||||
|
banned_mask = torch.LongTensor(banned_mask_list)
|
||||||
|
indices = torch.ones(len(banned_mask))
|
||||||
|
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
|
||||||
|
# [ 0 1 1 ]
|
||||||
|
# [ 0 0 0 ]
|
||||||
|
# [ 1 0 0 ]
|
||||||
|
|
||||||
|
banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
|
||||||
|
scores.masked_fill_(banned_mask, -float("inf"))
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_filtering(
|
def top_k_top_p_filtering(
|
||||||
logits: Tensor,
|
logits: Tensor,
|
||||||
top_k: int = 0,
|
top_k: int = 0,
|
||||||
|
|||||||
Reference in New Issue
Block a user