[Generate Tests] Make sure no tokens are force-generated (#18053)
This commit is contained in:
committed by
GitHub
parent
91c4a3ab1a
commit
2544c1434f
@@ -116,6 +116,12 @@ class BartModelTester:
|
|||||||
self.pad_token_id = pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
|
# forcing a certain token to be generated, sets all other tokens to -inf
|
||||||
|
# if however the token to be generated is already at -inf then it can lead to
|
||||||
|
# `nan` values and thus break generation
|
||||||
|
self.forced_bos_token_id = None
|
||||||
|
self.forced_eos_token_id = None
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
||||||
@@ -145,6 +151,8 @@ class BartModelTester:
|
|||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
|
forced_bos_token_id=self.forced_bos_token_id,
|
||||||
|
forced_eos_token_id=self.forced_eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_pipeline_config(self):
|
def get_pipeline_config(self):
|
||||||
|
|||||||
@@ -107,6 +107,12 @@ class BlenderbotModelTester:
|
|||||||
self.pad_token_id = pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
|
# forcing a certain token to be generated, sets all other tokens to -inf
|
||||||
|
# if however the token to be generated is already at -inf then it can lead to
|
||||||
|
# `nan` values and thus break generation
|
||||||
|
self.forced_bos_token_id = None
|
||||||
|
self.forced_eos_token_id = None
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
||||||
3,
|
3,
|
||||||
@@ -135,6 +141,8 @@ class BlenderbotModelTester:
|
|||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
|
forced_bos_token_id=self.forced_bos_token_id,
|
||||||
|
forced_eos_token_id=self.forced_eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_pipeline_config(self):
|
def get_pipeline_config(self):
|
||||||
|
|||||||
@@ -107,6 +107,12 @@ class BlenderbotSmallModelTester:
|
|||||||
self.pad_token_id = pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
|
# forcing a certain token to be generated, sets all other tokens to -inf
|
||||||
|
# if however the token to be generated is already at -inf then it can lead to
|
||||||
|
# `nan` values and thus break generation
|
||||||
|
self.forced_bos_token_id = None
|
||||||
|
self.forced_eos_token_id = None
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
||||||
3,
|
3,
|
||||||
@@ -135,6 +141,8 @@ class BlenderbotSmallModelTester:
|
|||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
|
forced_bos_token_id=self.forced_bos_token_id,
|
||||||
|
forced_eos_token_id=self.forced_eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
|||||||
@@ -123,6 +123,12 @@ class MarianModelTester:
|
|||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
self.decoder_start_token_id = decoder_start_token_id
|
self.decoder_start_token_id = decoder_start_token_id
|
||||||
|
|
||||||
|
# forcing a certain token to be generated, sets all other tokens to -inf
|
||||||
|
# if however the token to be generated is already at -inf then it can lead to
|
||||||
|
# `nan` values and thus break generation
|
||||||
|
self.forced_bos_token_id = None
|
||||||
|
self.forced_eos_token_id = None
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
||||||
3,
|
3,
|
||||||
@@ -152,6 +158,8 @@ class MarianModelTester:
|
|||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
decoder_start_token_id=self.decoder_start_token_id,
|
decoder_start_token_id=self.decoder_start_token_id,
|
||||||
|
forced_bos_token_id=self.forced_bos_token_id,
|
||||||
|
forced_eos_token_id=self.forced_eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
|||||||
@@ -113,6 +113,12 @@ class MBartModelTester:
|
|||||||
self.pad_token_id = pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
|
# forcing a certain token to be generated, sets all other tokens to -inf
|
||||||
|
# if however the token to be generated is already at -inf then it can lead to
|
||||||
|
# `nan` values and thus break generation
|
||||||
|
self.forced_bos_token_id = None
|
||||||
|
self.forced_eos_token_id = None
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
||||||
@@ -142,6 +148,8 @@ class MBartModelTester:
|
|||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
|
forced_bos_token_id=self.forced_bos_token_id,
|
||||||
|
forced_eos_token_id=self.forced_eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
|||||||
@@ -104,6 +104,12 @@ class PegasusModelTester:
|
|||||||
self.pad_token_id = pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
self.bos_token_id = bos_token_id
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
|
# forcing a certain token to be generated, sets all other tokens to -inf
|
||||||
|
# if however the token to be generated is already at -inf then it can lead to
|
||||||
|
# `nan` values and thus break generation
|
||||||
|
self.forced_bos_token_id = None
|
||||||
|
self.forced_eos_token_id = None
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
|
||||||
@@ -151,6 +157,8 @@ class PegasusModelTester:
|
|||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
|
forced_bos_token_id=self.forced_bos_token_id,
|
||||||
|
forced_eos_token_id=self.forced_eos_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user