From ca2047bc352e32f8d6dc26f4e55c2556149230d9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 9 Mar 2020 19:18:11 +0100 Subject: [PATCH] refactor variable naming and improve tf generate in line with torch generate --- src/transformers/modeling_tf_utils.py | 219 ++++++++++++++++++++------ src/transformers/modeling_utils.py | 91 ++++++----- 2 files changed, 219 insertions(+), 91 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index c76faa9fe1..3feafcbfa3 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -459,6 +459,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): self, input_ids=None, max_length=None, + min_length=None, do_sample=True, early_stopping=False, num_beams=None, @@ -470,7 +471,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): pad_token_id=None, eos_token_ids=None, length_penalty=None, + no_repeat_ngram_size=None, num_return_sequences=None, + attention_mask=None, ): r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling and beam-search. @@ -564,6 +567,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ) max_length = max_length if max_length is not None else self.config.max_length + min_length = min_length if min_length is not None else self.config.min_length do_sample = do_sample if do_sample is not None else self.config.do_sample early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping num_beams = num_beams if num_beams is not None else self.config.num_beams @@ -575,6 +579,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_ids = eos_token_ids if eos_token_ids is not None else self.config.eos_token_ids length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty + no_repeat_ngram_size = ( + no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size + ) num_return_sequences = ( num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences ) @@ -587,6 +594,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): eos_token_ids = [eos_token_ids] assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer." + assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." assert isinstance(do_sample, bool), "`do_sample` should be a boolean." assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer." @@ -631,6 +639,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): num_beams >= num_return_sequences ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences" + # create attention mask if necessary + # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140 + if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids): + attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32) + elif attention_mask is None: + attention_mask = tf.ones_like(input_ids) + if pad_token_id is None and eos_token_ids is not None: logger.warning( "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_ids[0]) @@ -655,42 +670,55 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): input_ids = tf.broadcast_to( tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len) ) + attention_mask = tf.broadcast_to( + tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len) + ) input_ids = tf.reshape( input_ids, (effective_batch_size * num_beams, input_ids_len) ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) + attention_mask = tf.reshape( + attention_mask, (effective_batch_size * num_beams, input_ids_len) + ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) if num_beams > 1: output = self._generate_beam_search( input_ids, - cur_len, - max_length, - do_sample, - early_stopping, - temperature, - top_k, - top_p, - repetition_penalty, - pad_token_id, - eos_token_ids, - effective_batch_size, - num_return_sequences, - length_penalty, - num_beams, - vocab_size, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + early_stopping=early_stopping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + pad_token_id=pad_token_id, + eos_token_ids=eos_token_ids, + batch_size=effective_batch_size, + num_return_sequences=num_return_sequences, + length_penalty=length_penalty, + num_beams=num_beams, + vocab_size=vocab_size, + attention_mask=attention_mask, ) else: output = self._generate_no_beam_search( input_ids, - cur_len, - max_length, - do_sample, - temperature, - top_k, - top_p, - repetition_penalty, - pad_token_id, - eos_token_ids, - effective_batch_size, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + pad_token_id=pad_token_id, + eos_token_ids=eos_token_ids, + batch_size=effective_batch_size, + vocab_size=vocab_size, + attention_mask=attention_mask, ) return output @@ -700,14 +728,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): input_ids, cur_len, max_length, + min_length, do_sample, temperature, top_k, top_p, repetition_penalty, + no_repeat_ngram_size, pad_token_id, eos_token_ids, batch_size, + vocab_size, + attention_mask, ): """ Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated independantly. @@ -720,7 +752,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): past = None while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, past=past) + model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask) outputs = self(**model_inputs) next_token_logits = outputs[0][:, -1, :] @@ -735,6 +767,33 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ) next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties) + if no_repeat_ngram_size > 0: + # calculate a list of banned tokens to prevent repetitively generating the same ngrams + # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 + banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) + # create banned_tokens boolean mask + banned_tokens_indices_mask = [] + for banned_tokens_slice in banned_tokens: + banned_tokens_indices_mask.append( + [True if token in banned_tokens_slice else False for token in range(vocab_size)] + ) + + next_token_logits = set_tensor_by_indices_to_value( + next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") + ) + + # set eos token prob to zero if min_length is not reached + if eos_token_ids is not None and cur_len < min_length: + # create eos_token_ids boolean mask + is_token_logit_eos_token = tf.convert_to_tensor( + [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool + ) + eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) + + next_token_logits = set_tensor_by_indices_to_value( + next_token_logits, eos_token_indices_mask, -float("inf") + ) + if do_sample: # Temperature (higher temperature => more likely to sample low probability tokens) if temperature != 1.0: @@ -806,12 +865,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): input_ids, cur_len, max_length, + min_length, do_sample, early_stopping, temperature, top_k, top_p, repetition_penalty, + no_repeat_ngram_size, pad_token_id, eos_token_ids, batch_size, @@ -819,6 +880,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): length_penalty, num_beams, vocab_size, + attention_mask, ): """ Generate sequences for each example with beam search. """ @@ -829,7 +891,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): for _ in range(batch_size) ] - # scores for each sentence in the beam + # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times if do_sample is False: beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32) beam_scores_end = tf.zeros((batch_size, num_beams - 1), dtype=tf.float32) * 1e-9 @@ -845,7 +907,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): done = [False for _ in range(batch_size)] while cur_len < max_length: - model_inputs = self.prepare_inputs_for_generation(input_ids, past=past) + model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask) outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) @@ -860,12 +922,42 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ) next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties) - if do_sample: - # Temperature (higher temperature => more likely to sample low probability tokens) - if temperature != 1.0: - next_token_logits = next_token_logits / temperature + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + next_token_logits = next_token_logits / temperature - scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) + if no_repeat_ngram_size > 0: + # calculate a list of banned tokens to prevent repetitively generating the same ngrams + # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 + banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len) + # create banned_tokens boolean mask + banned_tokens_indices_mask = [] + for banned_tokens_slice in banned_tokens: + banned_tokens_indices_mask.append( + [True if token in banned_tokens_slice else False for token in range(vocab_size)] + ) + + next_token_logits = set_tensor_by_indices_to_value( + next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf") + ) + + # set eos token prob to zero if min_length is not reached + if eos_token_ids is not None and cur_len < min_length: + # create eos_token_ids boolean mask + is_token_logit_eos_token = tf.convert_to_tensor( + [True if token in eos_token_ids else False for token in range(vocab_size)], dtype=tf.bool + ) + eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size]) + + next_token_logits = set_tensor_by_indices_to_value( + next_token_logits, eos_token_indices_mask, -float("inf") + ) + + # calculate log softmax score + scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) + assert shape_list(scores) == [batch_size * num_beams, vocab_size] + + if do_sample: _scores = scores + tf.broadcast_to( beam_scores[:, None], (batch_size * num_beams, vocab_size) ) # (batch_size * num_beams, vocab_size) @@ -888,9 +980,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2) next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2) else: - # do greedy beam search - scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) - assert shape_list(scores) == [batch_size * num_beams, vocab_size] # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product) next_scores = scores + tf.broadcast_to( beam_scores[:, None], (batch_size * num_beams, vocab_size) @@ -912,10 +1001,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): # for each sentence for batch_idx in range(batch_size): - # if we are done with this sentence - done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( - tf.reduce_max(next_scores[batch_idx]).numpy() - ) if done[batch_idx]: assert ( len(generated_hyps[batch_idx]) >= num_beams @@ -930,29 +1015,46 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): next_sent_beam = [] # next tokens for this sentence - for idx, score in zip(next_tokens[batch_idx], next_scores[batch_idx]): + for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx]) + ): # get beam and token IDs - beam_id = idx // vocab_size - token_id = idx % vocab_size + beam_id = beam_token_id // vocab_size + token_id = beam_token_id % vocab_size effective_beam_id = batch_idx * num_beams + beam_id # add to generated hypotheses if end of sentence or last iteration if eos_token_ids is not None and token_id.numpy() in eos_token_ids: - generated_hyps[batch_idx].add(tf.identity(input_ids[effective_beam_id]), score.numpy()) + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams + if is_beam_token_worse_than_top_num_beams: + continue + generated_hyps[batch_idx].add( + tf.identity(input_ids[effective_beam_id]), beam_token_score.numpy() + ) else: # add next predicted token if it is not eos_token - next_sent_beam.append((score, token_id, effective_beam_id)) + next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) # the beam for next step is full if len(next_sent_beam) == num_beams: break + # if we are done with this sentence + done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done( + tf.reduce_max(next_scores[batch_idx]).numpy() + ) + # update next beam content assert len(next_sent_beam) == num_beams, "Beam should always be full" next_batch_beam.extend(next_sent_beam) assert len(next_batch_beam) == num_beams * (batch_idx + 1) + # stop when we are done with each sentence + if all(done): + break + # sanity check / prepare next batch assert len(next_batch_beam) == batch_size * num_beams beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32) @@ -967,10 +1069,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): if past: past = self._reorder_cache(past, beam_idx) - # stop when we are done with each sentence - if all(done): - break - # update current length cur_len = cur_len + 1 @@ -1072,6 +1170,29 @@ def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty): return tf.convert_to_tensor(token_penalties, dtype=tf.float32) +def calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len): + # Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < no_repeat_ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].numpy().tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + + def _get_generated_ngrams(hypo_idx): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - no_repeat_ngram_size + ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist()) + return generated_ngrams[hypo_idx].get(ngram_idx, []) + + banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)] + return banned_tokens + + def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1): """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering Args: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a69f89167d..84e6243f87 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -722,6 +722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): max_length = max_length if max_length is not None else self.config.max_length min_length = min_length if min_length is not None else self.config.min_length + do_sample = do_sample if do_sample is not None else self.config.do_sample early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping num_beams = num_beams if num_beams is not None else self.config.num_beams temperature = temperature if temperature is not None else self.config.temperature @@ -852,7 +853,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): device=next(self.parameters()).device, ) cur_len = 1 - self.model.decoder.generation_mode = True + + # put model in generation mode if it has one + if hasattr(self.model, "generation_mode"): + self.model.decoder.generation_mode = True else: encoder_inputs = None cur_len = input_ids.shape[-1] @@ -860,44 +864,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): if num_beams > 1: output = self._generate_beam_search( input_ids, - cur_len, - max_length, - min_length, - do_sample, - early_stopping, - temperature, - top_k, - top_p, - repetition_penalty, - no_repeat_ngram_size, - bos_token_id, - pad_token_id, - eos_token_ids, - effective_batch_size, - num_return_sequences, - length_penalty, - num_beams, - vocab_size, - encoder_inputs, - attention_mask, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + early_stopping=early_stopping, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + bos_token_id=bos_token_id, + pad_token_id=pad_token_id, + eos_token_ids=eos_token_ids, + batch_size=effective_batch_size, + num_return_sequences=num_return_sequences, + length_penalty=length_penalty, + num_beams=num_beams, + vocab_size=vocab_size, + encoder_inputs=encoder_inputs, + attention_mask=attention_mask, ) else: output = self._generate_no_beam_search( input_ids, - cur_len, - max_length, - min_length, - do_sample, - temperature, - top_k, - top_p, - repetition_penalty, - no_repeat_ngram_size, - pad_token_id, - eos_token_ids, - effective_batch_size, - encoder_inputs, - attention_mask, + cur_len=cur_len, + max_length=max_length, + min_length=min_length, + do_sample=do_sample, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + pad_token_id=pad_token_id, + eos_token_ids=eos_token_ids, + batch_size=effective_batch_size, + encoder_inputs=encoder_inputs, + attention_mask=attention_mask, ) return output @@ -1157,24 +1161,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): next_sent_beam = [] # next tokens for this sentence - for i, (idx, score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])): + for beam_token_rank, (beam_token_id, beam_token_score) in enumerate( + zip(next_tokens[batch_idx], next_scores[batch_idx]) + ): # get beam and word IDs - beam_id = idx // vocab_size - token_id = idx % vocab_size + beam_id = beam_token_id // vocab_size + token_id = beam_token_id % vocab_size effective_beam_id = batch_idx * num_beams + beam_id # add to generated hypotheses if end of sentence if (eos_token_ids is not None) and (token_id.item() in eos_token_ids): - # when passed to num_beams hypotheses, continue - if i >= num_beams: + # if beam_token does not belong to top num_beams tokens, it should not be added + is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams + if is_beam_token_worse_than_top_num_beams: continue generated_hyps[batch_idx].add( - input_ids[effective_beam_id].clone(), score.item(), + input_ids[effective_beam_id].clone(), beam_token_score.item(), ) else: # add next predicted word if it is not eos_token - next_sent_beam.append((score, token_id, effective_beam_id)) + next_sent_beam.append((beam_token_score, token_id, effective_beam_id)) # the beam for next step is full if len(next_sent_beam) == num_beams: