[Performance improvement] "Bad tokens ids" optimization (#6064)

* Optimized banned token masking

* Avoid duplicate EOS masking if in bad_words_id

* Updated mask generation to handle empty banned token list

* Addition of unit tests for the updated bad_words_ids masking

* Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test

* Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test (timeout does not work on Windows)

* Moving Marian import to the test context to allow TF only environments to run

* Moving imports to torch_available test

* Updated operations device and test

* Updated operations device and test

* Added docstring and comment for in-place scores modification

* Moving test to own test_generation_utils, use of lighter models for testing

* removed unneded imports in test_modeling_common

* revert formatting change for ModelTesterMixin

* Updated caching, simplified eos token id test, removed unnecessary @require_torch

* formatting compliance
This commit is contained in:
guillaume-be
2020-08-11 11:56:40 +02:00
committed by GitHub
parent 87e124c245
commit 404782912a
2 changed files with 121 additions and 6 deletions

View File

@@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
from torch import Tensor
@@ -89,11 +89,12 @@ class GenerationMixin:
scores[i, banned_tokens] = -float("inf")
if bad_words_ids is not None:
# Exclude EOS token (already processed)
bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
# calculate a list of banned tokens according to bad words
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
for i, banned_tokens in enumerate(banned_tokens):
scores[i, banned_tokens] = -float("inf")
banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
# Modify the scores in place by setting the banned tokens logits to `-inf`
set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
return scores
@@ -893,7 +894,7 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
bad_words_ids
)
if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
# if tokens do not match continue
continue
@@ -904,6 +905,30 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
return banned_tokens
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
""" Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
Args:
scores: logits distribution of shape (batch size, vocabulary size)
banned_tokens: list of list of tokens to ban of length (batch_size)
"""
banned_mask_list = []
for idx, batch_banned_tokens in enumerate(banned_tokens):
for token in batch_banned_tokens:
banned_mask_list.append([idx, token])
if not banned_mask_list:
return
banned_mask = torch.LongTensor(banned_mask_list)
indices = torch.ones(len(banned_mask))
# A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
# [ 0 1 1 ]
# [ 0 0 0 ]
# [ 1 0 0 ]
banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
scores.masked_fill_(banned_mask, -float("inf"))
def top_k_top_p_filtering(
logits: Tensor,
top_k: int = 0,