[Generate] Add bad words list argument to the generate function (#3367)
* add bad words list * make style * add bad_words_tokens * make style * better naming * make style * fix typo
This commit is contained in:
committed by
GitHub
parent
ae6834e028
commit
b38d552a92
@@ -427,14 +427,14 @@ class TFModelTesterMixin:
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True))
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True))
|
||||
# batch_size = 1, num_beams > 1
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3))
|
||||
else:
|
||||
# batch_size = 1
|
||||
self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
|
||||
# batch_size = 1, num_beams > 1
|
||||
self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=3))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating multiple sequences when greedy no beam generation
|
||||
@@ -446,24 +446,52 @@ class TFModelTesterMixin:
|
||||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# batch_size > 1, sample
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=3))
|
||||
# batch_size > 1, greedy
|
||||
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=False))
|
||||
|
||||
# batch_size > 1, num_beams > 1, sample
|
||||
self._check_generated_tokens(
|
||||
model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
|
||||
)
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,))
|
||||
# batch_size > 1, num_beams > 1, greedy
|
||||
self._check_generated_tokens(
|
||||
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
|
||||
)
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3))
|
||||
|
||||
def _check_generated_tokens(self, output_ids):
|
||||
# check bad words tokens language generation
|
||||
bad_words_ids = [
|
||||
tf.squeeze(ids_tensor((1, 1), self.model_tester.vocab_size), -1).numpy().tolist(),
|
||||
tf.squeeze(ids_tensor((2, 1), self.model_tester.vocab_size), -1).numpy().tolist(),
|
||||
]
|
||||
|
||||
# sampling
|
||||
output_tokens = model.generate(
|
||||
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=3
|
||||
)
|
||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
# beam search
|
||||
output_tokens = model.generate(
|
||||
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=3, num_return_sequences=3
|
||||
)
|
||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
def _check_generated_ids(self, output_ids):
|
||||
for token_id in output_ids[0].numpy().tolist():
|
||||
self.assertGreaterEqual(token_id, 0)
|
||||
self.assertLess(token_id, self.model_tester.vocab_size)
|
||||
|
||||
def _check_match_tokens(self, generated_ids, bad_words_ids):
|
||||
# for all bad word tokens
|
||||
for bad_word_ids in bad_words_ids:
|
||||
# for all slices in batch
|
||||
for generated_ids_slice in generated_ids:
|
||||
# for all word idx
|
||||
for i in range(len(bad_word_ids), len(generated_ids_slice)):
|
||||
# if tokens match
|
||||
if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
|
||||
Reference in New Issue
Block a user