[Generate, Test] Split generate test function into beam search, no beam search (#3601)
* split beam search and no beam search test * fix test * clean generate tests
This commit is contained in:
committed by
GitHub
parent
1789c7daf1
commit
2ee410560e
@@ -624,70 +624,96 @@ class ModelTesterMixin:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model(**inputs_dict)
|
model(**inputs_dict)
|
||||||
|
|
||||||
def test_lm_head_model_random_generate(self):
|
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
input_ids = inputs_dict.get("input_ids")
|
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||||
|
|
||||||
|
# iterate over all generative models
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
if config.bos_token_id is None:
|
||||||
|
# if bos token id is not defined mobel needs input_ids
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
model.generate(do_sample=True, max_length=5)
|
||||||
|
# num_return_sequences = 1
|
||||||
|
self._check_generated_ids(model.generate(input_ids, do_sample=True))
|
||||||
|
else:
|
||||||
|
# num_return_sequences = 1
|
||||||
|
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
|
||||||
|
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
# generating multiple sequences when no beam search generation
|
||||||
|
# is not allowed as it would always generate the same sequences
|
||||||
|
model.generate(input_ids, do_sample=False, num_return_sequences=2)
|
||||||
|
|
||||||
|
# num_return_sequences > 1, sample
|
||||||
|
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))
|
||||||
|
|
||||||
|
# check bad words tokens language generation
|
||||||
|
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||||
|
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||||
|
output_tokens = model.generate(
|
||||||
|
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
|
||||||
|
)
|
||||||
|
# only count generated tokens
|
||||||
|
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||||
|
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
|
||||||
|
|
||||||
|
def test_lm_head_model_random_beam_search_generate(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models
|
# needed for Bart beam search
|
||||||
|
config.output_past = True
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
if config.bos_token_id is None:
|
if config.bos_token_id is None:
|
||||||
with self.assertRaises(AssertionError):
|
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
|
||||||
model.generate(do_sample=True, max_length=5)
|
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))
|
||||||
# batch_size = 1
|
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=True))
|
|
||||||
# batch_size = 1, num_beams > 1
|
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3))
|
|
||||||
else:
|
else:
|
||||||
# batch_size = 1
|
# num_return_sequences = 1
|
||||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
|
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
|
||||||
# batch_size = 1, num_beams > 1
|
|
||||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=3))
|
|
||||||
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
# generating multiple sequences when greedy no beam generation
|
|
||||||
# is not allowed as it would always generate the same sequences
|
|
||||||
model.generate(input_ids, do_sample=False, num_return_sequences=2)
|
|
||||||
|
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
# generating more sequences than having beams leads is not possible
|
# generating more sequences than having beams leads is not possible
|
||||||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||||
|
|
||||||
# batch_size > 1, sample
|
# num_return_sequences > 1, sample
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=3))
|
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2, num_return_sequences=2,))
|
||||||
# batch_size > 1, greedy
|
# num_return_sequences > 1, greedy
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=False))
|
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2))
|
||||||
|
|
||||||
# batch_size > 1, num_beams > 1, sample
|
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,))
|
|
||||||
# batch_size > 1, num_beams > 1, greedy
|
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3))
|
|
||||||
|
|
||||||
# check bad words tokens language generation
|
# check bad words tokens language generation
|
||||||
bad_words_ids = [
|
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||||
ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(-1).tolist(),
|
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||||
ids_tensor((2, 1), self.model_tester.vocab_size).squeeze(-1).tolist(),
|
|
||||||
]
|
|
||||||
|
|
||||||
# sampling
|
|
||||||
output_tokens = model.generate(
|
output_tokens = model.generate(
|
||||||
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=3
|
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
|
||||||
)
|
)
|
||||||
|
# only count generated tokens
|
||||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||||
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
|
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
|
||||||
|
|
||||||
# beam search
|
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||||
output_tokens = model.generate(
|
# special tokens cannot be bad tokens
|
||||||
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=3, num_return_sequences=3
|
special_tokens = []
|
||||||
)
|
if model.config.bos_token_id is not None:
|
||||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
special_tokens.append(model.config.bos_token_id)
|
||||||
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
|
if model.config.pad_token_id is not None:
|
||||||
|
special_tokens.append(model.config.pad_token_id)
|
||||||
|
if model.config.eos_token_id is not None:
|
||||||
|
special_tokens.append(model.config.eos_token_id)
|
||||||
|
|
||||||
|
# create random bad tokens that are not special tokens
|
||||||
|
bad_tokens = []
|
||||||
|
while len(bad_tokens) < num_bad_tokens:
|
||||||
|
token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).numpy()[0]
|
||||||
|
if token not in special_tokens:
|
||||||
|
bad_tokens.append(token)
|
||||||
|
return bad_tokens
|
||||||
|
|
||||||
def _check_generated_ids(self, output_ids):
|
def _check_generated_ids(self, output_ids):
|
||||||
for token_id in output_ids[0].tolist():
|
for token_id in output_ids[0].tolist():
|
||||||
|
|||||||
@@ -420,68 +420,96 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
model(inputs_dict)
|
model(inputs_dict)
|
||||||
|
|
||||||
def test_lm_head_model_random_generate(self):
|
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||||
|
|
||||||
|
# iterate over all generative models
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
if config.bos_token_id is None:
|
||||||
|
# if bos token id is not defined mobel needs input_ids
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
model.generate(do_sample=True, max_length=5)
|
||||||
|
# num_return_sequences = 1
|
||||||
|
self._check_generated_ids(model.generate(input_ids, do_sample=True))
|
||||||
|
else:
|
||||||
|
# num_return_sequences = 1
|
||||||
|
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
|
||||||
|
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
# generating multiple sequences when no beam search generation
|
||||||
|
# is not allowed as it would always generate the same sequences
|
||||||
|
model.generate(input_ids, do_sample=False, num_return_sequences=2)
|
||||||
|
|
||||||
|
# num_return_sequences > 1, sample
|
||||||
|
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))
|
||||||
|
|
||||||
|
# check bad words tokens language generation
|
||||||
|
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||||
|
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||||
|
output_tokens = model.generate(
|
||||||
|
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
|
||||||
|
)
|
||||||
|
# only count generated tokens
|
||||||
|
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||||
|
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||||
|
|
||||||
|
def test_lm_head_model_random_beam_search_generate(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||||
|
|
||||||
if self.is_encoder_decoder:
|
if self.is_encoder_decoder:
|
||||||
config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models
|
# needed for Bart beam search
|
||||||
|
config.output_past = True
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
if config.bos_token_id is None:
|
if config.bos_token_id is None:
|
||||||
with self.assertRaises(AssertionError):
|
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
|
||||||
model.generate(do_sample=True, max_length=5)
|
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))
|
||||||
# batch_size = 1
|
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=True))
|
|
||||||
# batch_size = 1, num_beams > 1
|
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3))
|
|
||||||
else:
|
else:
|
||||||
# batch_size = 1
|
# num_return_sequences = 1
|
||||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
|
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
|
||||||
# batch_size = 1, num_beams > 1
|
|
||||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=3))
|
|
||||||
|
|
||||||
with self.assertRaises(AssertionError):
|
|
||||||
# generating multiple sequences when greedy no beam generation
|
|
||||||
# is not allowed as it would always generate the same sequences
|
|
||||||
model.generate(input_ids, do_sample=False, num_return_sequences=2)
|
|
||||||
|
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(AssertionError):
|
||||||
# generating more sequences than having beams leads is not possible
|
# generating more sequences than having beams leads is not possible
|
||||||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||||
|
|
||||||
# batch_size > 1, sample
|
# num_return_sequences > 1, sample
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=3))
|
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2, num_return_sequences=2,))
|
||||||
# batch_size > 1, greedy
|
# num_return_sequences > 1, greedy
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=False))
|
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2))
|
||||||
|
|
||||||
# batch_size > 1, num_beams > 1, sample
|
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,))
|
|
||||||
# batch_size > 1, num_beams > 1, greedy
|
|
||||||
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3))
|
|
||||||
|
|
||||||
# check bad words tokens language generation
|
# check bad words tokens language generation
|
||||||
bad_words_ids = [
|
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||||
tf.squeeze(ids_tensor((1, 1), self.model_tester.vocab_size), -1).numpy().tolist(),
|
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||||
tf.squeeze(ids_tensor((2, 1), self.model_tester.vocab_size), -1).numpy().tolist(),
|
|
||||||
]
|
|
||||||
|
|
||||||
# sampling
|
|
||||||
output_tokens = model.generate(
|
output_tokens = model.generate(
|
||||||
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=3
|
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
|
||||||
)
|
)
|
||||||
|
# only count generated tokens
|
||||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||||
|
|
||||||
# beam search
|
def _generate_random_bad_tokens(self, num_bad_tokens, model):
|
||||||
output_tokens = model.generate(
|
# special tokens cannot be bad tokens
|
||||||
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=3, num_return_sequences=3
|
special_tokens = []
|
||||||
)
|
if model.config.bos_token_id is not None:
|
||||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
special_tokens.append(model.config.bos_token_id)
|
||||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
if model.config.pad_token_id is not None:
|
||||||
|
special_tokens.append(model.config.pad_token_id)
|
||||||
|
if model.config.eos_token_id is not None:
|
||||||
|
special_tokens.append(model.config.eos_token_id)
|
||||||
|
|
||||||
|
# create random bad tokens that are not special tokens
|
||||||
|
bad_tokens = []
|
||||||
|
while len(bad_tokens) < num_bad_tokens:
|
||||||
|
token = tf.squeeze(ids_tensor((1, 1), self.model_tester.vocab_size), 0).numpy()[0]
|
||||||
|
if token not in special_tokens:
|
||||||
|
bad_tokens.append(token)
|
||||||
|
return bad_tokens
|
||||||
|
|
||||||
def _check_generated_ids(self, output_ids):
|
def _check_generated_ids(self, output_ids):
|
||||||
for token_id in output_ids[0].numpy().tolist():
|
for token_id in output_ids[0].numpy().tolist():
|
||||||
|
|||||||
Reference in New Issue
Block a user