From c4c4c9998a2e335706dd7638bbd90513234077f3 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 4 Mar 2020 11:09:45 +0100 Subject: [PATCH] make GPT2 and CTRL shape consistent between torch and TF --- src/transformers/modeling_tf_ctrl.py | 4 +-- src/transformers/modeling_tf_gpt2.py | 4 +-- src/transformers/modeling_tf_utils.py | 35 ++++++++++++++++++++------- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py index 335421979c..8a049bbce9 100644 --- a/src/transformers/modeling_tf_ctrl.py +++ b/src/transformers/modeling_tf_ctrl.py @@ -104,10 +104,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): k = self.split_into_heads(k, batch_size) v = self.split_into_heads(v, batch_size) if layer_past is not None: - past_key, past_value = tf.unstack(layer_past, axis=1) + past_key, past_value = tf.unstack(layer_past, axis=0) k = tf.concat((past_key, k), axis=-2) v = tf.concat((past_value, v), axis=-2) - present = tf.stack((k, v), axis=1) + present = tf.stack((k, v), axis=0) output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask) scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3]) diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py index 7e9b102b6d..3b79d58949 100644 --- a/src/transformers/modeling_tf_gpt2.py +++ b/src/transformers/modeling_tf_gpt2.py @@ -139,10 +139,10 @@ class TFAttention(tf.keras.layers.Layer): key = self.split_heads(key) value = self.split_heads(value) if layer_past is not None: - past_key, past_value = tf.unstack(layer_past, axis=1) + past_key, past_value = tf.unstack(layer_past, axis=0) key = tf.concat([past_key, key], axis=-2) value = tf.concat([past_value, value], axis=-2) - present = tf.stack([key, value], axis=1) + present = tf.stack([key, value], axis=0) attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training) a = attn_outputs[0] diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index e3083b6d20..1dfeecdd8e 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -658,7 +658,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) if repetition_penalty != 1.0: - next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty) + next_token_logits_penalties = _create_next_token_logits_penalties( + input_ids, next_token_logits, repetition_penalty + ) next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties) if do_sample: @@ -779,7 +781,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) if repetition_penalty != 1.0: - next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty) + next_token_logits_penalties = _create_next_token_logits_penalties( + input_ids, next_token_logits, repetition_penalty + ) next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties) if do_sample: @@ -791,11 +795,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 ) # (batch_size * num_beams, vocab_size) # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) - next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2) # (batch_size * num_beams, vocab_size) + next_tokens = tf.random.categorical( + next_token_logits, dtype=tf.int32, num_samples=2 + ) # (batch_size * num_beams, vocab_size) # Compute next scores scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) _scores = tf.gather(scores, next_tokens, batch_dims=1) # (batch_size * num_beams, 2) - next_scores = _scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, 2)) # (batch_size * num_beams, 2) + next_scores = _scores + tf.broadcast_to( + beam_scores[:, None], (batch_size * num_beams, 2) + ) # (batch_size * num_beams, 2) # Match shape of greedy beam search next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams) next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams) @@ -804,10 +812,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) assert shape_list(scores) == [batch_size * num_beams, vocab_size] # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product) - next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size)) # (batch_size * num_beams, vocab_size) + next_scores = scores + tf.broadcast_to( + beam_scores[:, None], (batch_size * num_beams, vocab_size) + ) # (batch_size * num_beams, vocab_size) # re-organize to group the beam together (we are keeping top hypothesis accross beams) - next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size)) # (batch_size, num_beams * vocab_size) + next_scores = tf.reshape( + next_scores, (batch_size, num_beams * vocab_size) + ) # (batch_size, num_beams * vocab_size) next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True) assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams] @@ -909,7 +921,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): best_hyp = sorted_hyps.pop()[1] sent_lengths_list.append(len(best_hyp)) best.append(best_hyp) - assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(output_batch_size, len(best)) + assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format( + output_batch_size, len(best) + ) sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32) @@ -925,7 +939,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): decoded_hypo = tf.concat([hypo, padding], axis=0) if sent_lengths[i] < max_length: - decoded_hypo = tf.where(tf.range(max_length) == sent_lengths[i], eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), decoded_hypo) + decoded_hypo = tf.where( + tf.range(max_length) == sent_lengths[i], + eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), + decoded_hypo, + ) decoded_list.append(decoded_hypo) decoded = tf.stack(decoded_list) else: @@ -942,7 +960,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): # get the correct batch idx from layer past batch dim # batch dim of `past` and `mems` is at 2nd position reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx] - # TODO: check whether it is an error that TF past.shape != Torch past.shape reordered_layer_past = tf.concat(reordered_layer_past, axis=1) # check that shape matches assert shape_list(reordered_layer_past) == shape_list(layer_past)