[Generate] Add bad words list argument to the generate function (#3367)

* add bad words list

* make style

* add bad_words_tokens

* make style

* better naming

* make style

* fix typo
This commit is contained in:
Patrick von Platen
2020-03-31 18:42:31 +02:00
committed by GitHub
parent ae6834e028
commit b38d552a92
5 changed files with 240 additions and 32 deletions

View File

@@ -427,14 +427,14 @@ class TFModelTesterMixin:
with self.assertRaises(AssertionError):
model.generate(do_sample=True, max_length=5)
# batch_size = 1
self._check_generated_tokens(model.generate(input_ids, do_sample=True))
self._check_generated_ids(model.generate(input_ids, do_sample=True))
# batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_beams=3))
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3))
else:
# batch_size = 1
self._check_generated_tokens(model.generate(do_sample=True, max_length=5))
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
# batch_size = 1, num_beams > 1
self._check_generated_tokens(model.generate(do_sample=True, max_length=5, num_beams=3))
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=3))
with self.assertRaises(AssertionError):
# generating multiple sequences when greedy no beam generation
@@ -446,24 +446,52 @@ class TFModelTesterMixin:
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
# batch_size > 1, sample
self._check_generated_tokens(model.generate(input_ids, do_sample=True, num_return_sequences=3))
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=3))
# batch_size > 1, greedy
self._check_generated_tokens(model.generate(input_ids, do_sample=False))
self._check_generated_ids(model.generate(input_ids, do_sample=False))
# batch_size > 1, num_beams > 1, sample
self._check_generated_tokens(
model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,)
)
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=3, num_return_sequences=3,))
# batch_size > 1, num_beams > 1, greedy
self._check_generated_tokens(
model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3)
)
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=3, num_return_sequences=3))
def _check_generated_tokens(self, output_ids):
# check bad words tokens language generation
bad_words_ids = [
tf.squeeze(ids_tensor((1, 1), self.model_tester.vocab_size), -1).numpy().tolist(),
tf.squeeze(ids_tensor((2, 1), self.model_tester.vocab_size), -1).numpy().tolist(),
]
# sampling
output_tokens = model.generate(
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=3
)
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
# beam search
output_tokens = model.generate(
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=3, num_return_sequences=3
)
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
def _check_generated_ids(self, output_ids):
for token_id in output_ids[0].numpy().tolist():
self.assertGreaterEqual(token_id, 0)
self.assertLess(token_id, self.model_tester.vocab_size)
def _check_match_tokens(self, generated_ids, bad_words_ids):
# for all bad word tokens
for bad_word_ids in bad_words_ids:
# for all slices in batch
for generated_ids_slice in generated_ids:
# for all word idx
for i in range(len(bad_word_ids), len(generated_ids_slice)):
# if tokens match
if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
return True
return False
def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
"""Creates a random int32 tensor of the shape within the vocab size."""