From b38d552a92a0a201c005afae0e1b861ae6de9ce0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 31 Mar 2020 18:42:31 +0200 Subject: [PATCH] [Generate] Add bad words list argument to the generate function (#3367) * add bad words list * make style * add bad_words_tokens * make style * better naming * make style * fix typo --- src/transformers/configuration_utils.py | 1 + src/transformers/modeling_tf_utils.py | 90 ++++++++++++++++++++++++- src/transformers/modeling_utils.py | 73 +++++++++++++++++++- tests/test_modeling_common.py | 54 +++++++++++---- tests/test_modeling_tf_common.py | 54 +++++++++++---- 5 files changed, 240 insertions(+), 32 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 01f6b6554a..67477b76d0 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -80,6 +80,7 @@ class PretrainedConfig(object): self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) self.length_penalty = kwargs.pop("length_penalty", 1.0) self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) + self.bad_words_ids = kwargs.pop("bad_words_ids", None) self.num_return_sequences = kwargs.pop("num_return_sequences", 1) # Fine-tuning task arguments diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 97400cf233..99264bb99b 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -467,6 +467,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): top_k=None, top_p=None, repetition_penalty=None, + bad_words_ids=None, bos_token_id=None, pad_token_id=None, eos_token_id=None, @@ -532,6 +533,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): no_repeat_ngram_size: (`optional`) int If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once. 
+ bad_words_ids: (`optional`) list of lists of int + `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. + num_return_sequences: (`optional`) int The number of independently computed returned sequences for each element in the batch. Default to 1. @@ -582,6 +586,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer + model = TFAutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache. + input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl + bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']] + input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated """ # We cannot generate if the model does not have a LM head @@ -607,6 +617,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): no_repeat_ngram_size = ( no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size ) + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) @@ -641,6 +652,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): assert ( isinstance(num_return_sequences, int) and num_return_sequences > 0 ), 
"`num_return_sequences` should be a strictely positive integer." + assert ( + bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) + ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" if input_ids is None: assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( @@ -742,6 +756,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): top_p=top_p, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, bos_token_id=bos_token_id, pad_token_id=pad_token_id, eos_token_id=eos_token_id, @@ -766,6 +781,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): top_p=top_p, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, bos_token_id=bos_token_id, pad_token_id=pad_token_id, eos_token_id=eos_token_id, @@ -790,6 +806,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): top_p, repetition_penalty, no_repeat_ngram_size, + bad_words_ids, bos_token_id, pad_token_id, eos_token_id, @@ -828,7 +845,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): if no_repeat_ngram_size > 0: # calculate a list of banned tokens to prevent repetitively generating the same ngrams # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 - banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) + banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) # create banned_tokens boolean mask banned_tokens_indices_mask = [] for banned_tokens_slice in banned_tokens: @@ -840,6 +857,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") ) + if bad_words_ids is not None: + # calculate a list of banned tokens 
according to bad words + banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids) + + banned_tokens_indices_mask = [] + for banned_tokens_slice in banned_tokens: + banned_tokens_indices_mask.append( + [True if token in banned_tokens_slice else False for token in range(vocab_size)] + ) + + next_token_logits = set_tensor_by_indices_to_value( + next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") + ) + # set eos token prob to zero if min_length is not reached if eos_token_id is not None and cur_len < min_length: # create eos_token_id boolean mask @@ -936,6 +967,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): top_p, repetition_penalty, no_repeat_ngram_size, + bad_words_ids, bos_token_id, pad_token_id, decoder_start_token_id, @@ -1012,7 +1044,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): # calculate a list of banned tokens to prevent repetitively generating the same ngrams # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 num_batch_hypotheses = batch_size * num_beams - banned_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len) + banned_tokens = calc_banned_ngram_tokens( + input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len + ) # create banned_tokens boolean mask banned_tokens_indices_mask = [] for banned_tokens_slice in banned_tokens: @@ -1024,6 +1058,20 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") ) + if bad_words_ids is not None: + # calculate a list of banned tokens according to bad words + banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids) + + banned_tokens_indices_mask = [] + for banned_tokens_slice in banned_tokens: + banned_tokens_indices_mask.append( + [True if token in banned_tokens_slice else False for token in 
def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
    """Compute, per hypothesis, the token ids that must not be generated next.

    Each bad word is a sequence of token ids. Its last token is banned for a
    hypothesis when the tokens of that hypothesis end with all preceding
    tokens of the bad word. A single-token bad word has an empty prefix and
    is therefore banned unconditionally.

    Args:
        prev_input_ids: iterable of per-hypothesis rows of token ids. Every
            row must support ``.numpy().tolist()``.
        bad_words_ids: list of non-empty lists of token ids.

    Returns:
        A list with one entry per row of ``prev_input_ids``. Each entry is the
        list of banned token ids for that row, in ``bad_words_ids`` order
        (duplicates are not removed).

    Raises:
        AssertionError: if a sequence in ``bad_words_ids`` is empty.
    """
    banned_tokens = []

    def _tokens_match(prev_tokens, tokens):
        if len(tokens) == 0:
            # the bad word is a single token: no prefix is required, always ban it
            return True
        if len(tokens) > len(prev_tokens):
            # the hypothesis is shorter than the required prefix, so it cannot end
            # with it (compare with the row length, not with the number of rows)
            return False
        return prev_tokens[-len(tokens):] == tokens

    for prev_input_ids_slice in prev_input_ids:
        # convert the row once instead of once per bad word
        prev_tokens = prev_input_ids_slice.numpy().tolist()
        banned_tokens_slice = []

        for banned_token_seq in bad_words_ids:
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
                bad_words_ids
            )

            if _tokens_match(prev_tokens, banned_token_seq[:-1]):
                # the hypothesis ends with the bad word's prefix: forbid its last token
                banned_tokens_slice.append(banned_token_seq[-1])

        banned_tokens.append(banned_tokens_slice)

    return banned_tokens
top-k and/or nucleus (top-p) filtering Args: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6063fccc41..86077769af 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -667,6 +667,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): top_k=None, top_p=None, repetition_penalty=None, + bad_words_ids=None, bos_token_id=None, pad_token_id=None, eos_token_id=None, @@ -731,6 +732,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): no_repeat_ngram_size: (`optional`) int If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once. + bad_words_ids: (`optional`) list of lists of int + `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`. num_return_sequences: (`optional`) int The number of independently computed returned sequences for each element in the batch. Default to 1. @@ -782,6 +785,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache. 
+ input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl + bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']] + input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context + outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated """ # We cannot generate if the model does not have a LM head @@ -807,6 +816,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): no_repeat_ngram_size = ( no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size ) + bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) @@ -844,6 +854,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): assert ( isinstance(num_return_sequences, int) and num_return_sequences > 0 ), "`num_return_sequences` should be a strictly positive integer." 
+ assert ( + bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list) + ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated" if input_ids is None: assert isinstance(bos_token_id, int) and bos_token_id >= 0, ( @@ -964,6 +977,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): top_p=top_p, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, bos_token_id=bos_token_id, pad_token_id=pad_token_id, decoder_start_token_id=decoder_start_token_id, @@ -988,6 +1002,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): top_p=top_p, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, + bad_words_ids=bad_words_ids, bos_token_id=bos_token_id, pad_token_id=pad_token_id, decoder_start_token_id=decoder_start_token_id, @@ -1011,6 +1026,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): top_p, repetition_penalty, no_repeat_ngram_size, + bad_words_ids, bos_token_id, pad_token_id, eos_token_id, @@ -1045,7 +1061,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): if no_repeat_ngram_size > 0: # calculate a list of banned tokens to prevent repetitively generating the same ngrams # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 - banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) + banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) + for batch_idx in range(batch_size): + next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf") + + if bad_words_ids is not None: + # calculate a list of banned tokens according to bad words + banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids) + for batch_idx in range(batch_size): next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf") @@ -1121,6 +1144,7 @@ class 
def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
    """Compute, per hypothesis, the token ids that must not be generated next.

    Each bad word is a sequence of token ids. Its last token is banned for a
    hypothesis when the tokens of that hypothesis end with all preceding
    tokens of the bad word. A single-token bad word has an empty prefix and
    is therefore banned unconditionally.

    Args:
        prev_input_ids: iterable of per-hypothesis rows of token ids. Every
            row must support ``.tolist()``.
        bad_words_ids: list of non-empty lists of token ids.

    Returns:
        A list with one entry per row of ``prev_input_ids``. Each entry is the
        list of banned token ids for that row, in ``bad_words_ids`` order
        (duplicates are not removed).

    Raises:
        AssertionError: if a sequence in ``bad_words_ids`` is empty.
    """
    banned_tokens = []

    def _tokens_match(prev_tokens, tokens):
        if len(tokens) == 0:
            # the bad word is a single token: no prefix is required, always ban it
            return True
        if len(tokens) > len(prev_tokens):
            # the hypothesis is shorter than the required prefix, so it cannot end
            # with it (compare with the row length, not with the number of rows)
            return False
        return prev_tokens[-len(tokens):] == tokens

    for prev_input_ids_slice in prev_input_ids:
        # convert the row once instead of once per bad word
        prev_tokens = prev_input_ids_slice.tolist()
        banned_tokens_slice = []

        for banned_token_seq in bad_words_ids:
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
                bad_words_ids
            )

            if _tokens_match(prev_tokens, banned_token_seq[:-1]):
                # the hypothesis ends with the bad word's prefix: forbid its last token
                banned_tokens_slice.append(banned_token_seq[-1])

        banned_tokens.append(banned_tokens_slice)

    return banned_tokens
def _check_generated_ids(self, output_ids):
    """Assert that every token id of the first generated sequence lies in ``[0, vocab_size)``."""
    for token_id in output_ids[0].tolist():
        self.assertGreaterEqual(token_id, 0)
        self.assertLess(token_id, self.model_tester.vocab_size)


def _check_match_tokens(self, generated_ids, bad_words_ids):
    """Return ``True`` if any bad word occurs contiguously in any generated sequence.

    Args:
        generated_ids: list of generated sequences, each a list of token ids.
        bad_words_ids: list of bad words, each a list of token ids.

    Returns:
        ``True`` as soon as one window of a sequence equals a bad word,
        ``False`` if no sequence contains any bad word.
    """
    # for all bad word tokens
    for bad_word_ids in bad_words_ids:
        width = len(bad_word_ids)
        # for all slices in batch
        for generated_ids_slice in generated_ids:
            # `end` is the exclusive end of the window, so it has to reach
            # len(generated_ids_slice) for the final window to be inspected
            for end in range(width, len(generated_ids_slice) + 1):
                if generated_ids_slice[end - width : end] == bad_word_ids:
                    return True
    return False
def _check_generated_ids(self, output_ids):
    """Assert that every token id of the first generated sequence lies in ``[0, vocab_size)``."""
    for token_id in output_ids[0].numpy().tolist():
        self.assertGreaterEqual(token_id, 0)
        self.assertLess(token_id, self.model_tester.vocab_size)


def _check_match_tokens(self, generated_ids, bad_words_ids):
    """Return ``True`` if any bad word occurs contiguously in any generated sequence.

    Args:
        generated_ids: list of generated sequences, each a list of token ids.
        bad_words_ids: list of bad words, each a list of token ids.

    Returns:
        ``True`` as soon as one window of a sequence equals a bad word,
        ``False`` if no sequence contains any bad word.
    """
    # for all bad word tokens
    for bad_word_ids in bad_words_ids:
        width = len(bad_word_ids)
        # for all slices in batch
        for generated_ids_slice in generated_ids:
            # `end` is the exclusive end of the window, so it has to reach
            # len(generated_ids_slice) for the final window to be inspected
            for end in range(width, len(generated_ids_slice) + 1):
                if generated_ids_slice[end - width : end] == bad_word_ids:
                    return True
    return False