make GPT2 and CTRL shape consistent between torch and TF
This commit is contained in:
@@ -104,10 +104,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||
k = self.split_into_heads(k, batch_size)
|
||||
v = self.split_into_heads(v, batch_size)
|
||||
if layer_past is not None:
|
||||
past_key, past_value = tf.unstack(layer_past, axis=1)
|
||||
past_key, past_value = tf.unstack(layer_past, axis=0)
|
||||
k = tf.concat((past_key, k), axis=-2)
|
||||
v = tf.concat((past_value, v), axis=-2)
|
||||
present = tf.stack((k, v), axis=1)
|
||||
present = tf.stack((k, v), axis=0)
|
||||
|
||||
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
||||
scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
|
||||
|
||||
@@ -139,10 +139,10 @@ class TFAttention(tf.keras.layers.Layer):
|
||||
key = self.split_heads(key)
|
||||
value = self.split_heads(value)
|
||||
if layer_past is not None:
|
||||
past_key, past_value = tf.unstack(layer_past, axis=1)
|
||||
past_key, past_value = tf.unstack(layer_past, axis=0)
|
||||
key = tf.concat([past_key, key], axis=-2)
|
||||
value = tf.concat([past_value, value], axis=-2)
|
||||
present = tf.stack([key, value], axis=1)
|
||||
present = tf.stack([key, value], axis=0)
|
||||
|
||||
attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
|
||||
a = attn_outputs[0]
|
||||
|
||||
@@ -658,7 +658,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
|
||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
|
||||
next_token_logits_penalties = _create_next_token_logits_penalties(
|
||||
input_ids, next_token_logits, repetition_penalty
|
||||
)
|
||||
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
|
||||
|
||||
if do_sample:
|
||||
@@ -779,7 +781,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
|
||||
next_token_logits_penalties = _create_next_token_logits_penalties(
|
||||
input_ids, next_token_logits, repetition_penalty
|
||||
)
|
||||
next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
|
||||
|
||||
if do_sample:
|
||||
@@ -791,11 +795,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
|
||||
next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2) # (batch_size * num_beams, vocab_size)
|
||||
next_tokens = tf.random.categorical(
|
||||
next_token_logits, dtype=tf.int32, num_samples=2
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
# Compute next scores
|
||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||
_scores = tf.gather(scores, next_tokens, batch_dims=1) # (batch_size * num_beams, 2)
|
||||
next_scores = _scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, 2)) # (batch_size * num_beams, 2)
|
||||
next_scores = _scores + tf.broadcast_to(
|
||||
beam_scores[:, None], (batch_size * num_beams, 2)
|
||||
) # (batch_size * num_beams, 2)
|
||||
# Match shape of greedy beam search
|
||||
next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams)
|
||||
next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams)
|
||||
@@ -804,10 +812,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
|
||||
assert shape_list(scores) == [batch_size * num_beams, vocab_size]
|
||||
# Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
|
||||
next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size)) # (batch_size * num_beams, vocab_size)
|
||||
next_scores = scores + tf.broadcast_to(
|
||||
beam_scores[:, None], (batch_size * num_beams, vocab_size)
|
||||
) # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
||||
next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size)) # (batch_size, num_beams * vocab_size)
|
||||
next_scores = tf.reshape(
|
||||
next_scores, (batch_size, num_beams * vocab_size)
|
||||
) # (batch_size, num_beams * vocab_size)
|
||||
next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
|
||||
|
||||
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
|
||||
@@ -909,7 +921,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
best_hyp = sorted_hyps.pop()[1]
|
||||
sent_lengths_list.append(len(best_hyp))
|
||||
best.append(best_hyp)
|
||||
assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(output_batch_size, len(best))
|
||||
assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
|
||||
output_batch_size, len(best)
|
||||
)
|
||||
|
||||
sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
|
||||
|
||||
@@ -925,7 +939,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
decoded_hypo = tf.concat([hypo, padding], axis=0)
|
||||
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded_hypo = tf.where(tf.range(max_length) == sent_lengths[i], eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), decoded_hypo)
|
||||
decoded_hypo = tf.where(
|
||||
tf.range(max_length) == sent_lengths[i],
|
||||
eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
|
||||
decoded_hypo,
|
||||
)
|
||||
decoded_list.append(decoded_hypo)
|
||||
decoded = tf.stack(decoded_list)
|
||||
else:
|
||||
@@ -942,7 +960,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
# get the correct batch idx from layer past batch dim
|
||||
# batch dim of `past` and `mems` is at 2nd position
|
||||
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
|
||||
# TODO: check whether it is an error that TF past.shape != Torch past.shape
|
||||
reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
|
||||
# check that shape matches
|
||||
assert shape_list(reordered_layer_past) == shape_list(layer_past)
|
||||
|
||||
Reference in New Issue
Block a user