[Generate Tests] Make sure no tokens are force-generated (#18053)

This commit is contained in:
Patrick von Platen
2022-07-07 15:08:34 +02:00
committed by GitHub
parent 91c4a3ab1a
commit 2544c1434f
6 changed files with 48 additions and 0 deletions

View File

@@ -104,6 +104,12 @@ class PegasusModelTester:
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
# forcing a certain token to be generated, sets all other tokens to -inf
# if however the token to be generated is already at -inf then it can lead token
# `nan` values and thus break generation
self.forced_bos_token_id = None
self.forced_eos_token_id = None
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
@@ -151,6 +157,8 @@ class PegasusModelTester:
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
forced_bos_token_id=self.forced_bos_token_id,
forced_eos_token_id=self.forced_eos_token_id,
)
def prepare_config_and_inputs_for_common(self):