[T5, Testst] Add extensive hard-coded integration tests and make sure PT and TF give equal results (#3550)
* add some t5 integration tests * finish summarization and translation integration tests for T5 - results loook good * add tf test * fix == vs is bug * fix tf beam search error and make tf t5 tests pass
This commit is contained in:
committed by
GitHub
parent
8538ce9044
commit
b815edf69f
@@ -119,7 +119,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
|
||||
if self.has_relative_attention_bias:
|
||||
self.relative_attention_bias = tf.keras.layers.Embedding(
|
||||
self.relative_attention_num_buckets, self.n_heads, name="relative_attention_bias"
|
||||
self.relative_attention_num_buckets, self.n_heads, name="relative_attention_bias",
|
||||
)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@@ -178,13 +178,15 @@ class TFT5Attention(tf.keras.layers.Layer):
|
||||
memory_position = tf.range(klen)[None, :]
|
||||
relative_position = memory_position - context_position # shape (qlen, klen)
|
||||
rp_bucket = self._relative_position_bucket(
|
||||
relative_position, bidirectional=not self.is_decoder, num_buckets=self.relative_attention_num_buckets
|
||||
relative_position, bidirectional=not self.is_decoder, num_buckets=self.relative_attention_num_buckets,
|
||||
)
|
||||
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
|
||||
values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0) # shape (1, num_heads, qlen, klen)
|
||||
return values
|
||||
|
||||
def call(self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None, training=False):
|
||||
def call(
|
||||
self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None, training=False,
|
||||
):
|
||||
"""
|
||||
Self-attention (if kv is None) or attention over source sentence (provided by kv).
|
||||
"""
|
||||
@@ -261,15 +263,17 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, has_relative_attention_bias=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.SelfAttention = TFT5Attention(
|
||||
config, has_relative_attention_bias=has_relative_attention_bias, name="SelfAttention"
|
||||
config, has_relative_attention_bias=has_relative_attention_bias, name="SelfAttention",
|
||||
)
|
||||
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
||||
|
||||
def call(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, training=False):
|
||||
def call(
|
||||
self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, training=False,
|
||||
):
|
||||
norm_x = self.layer_norm(hidden_states)
|
||||
attention_output = self.SelfAttention(
|
||||
norm_x, mask=attention_mask, position_bias=position_bias, head_mask=head_mask, training=training
|
||||
norm_x, mask=attention_mask, position_bias=position_bias, head_mask=head_mask, training=training,
|
||||
)
|
||||
y = attention_output[0]
|
||||
layer_output = hidden_states + self.dropout(y, training=training)
|
||||
@@ -281,15 +285,17 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
|
||||
def __init__(self, config, has_relative_attention_bias=False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.EncDecAttention = TFT5Attention(
|
||||
config, has_relative_attention_bias=has_relative_attention_bias, name="EncDecAttention"
|
||||
config, has_relative_attention_bias=has_relative_attention_bias, name="EncDecAttention",
|
||||
)
|
||||
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
|
||||
|
||||
def call(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None, training=False):
|
||||
def call(
|
||||
self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None, training=False,
|
||||
):
|
||||
norm_x = self.layer_norm(hidden_states)
|
||||
attention_output = self.EncDecAttention(
|
||||
norm_x, mask=attention_mask, kv=kv, position_bias=position_bias, head_mask=head_mask, training=training
|
||||
norm_x, mask=attention_mask, kv=kv, position_bias=position_bias, head_mask=head_mask, training=training,
|
||||
)
|
||||
y = attention_output[0]
|
||||
layer_output = hidden_states + self.dropout(y, training=training)
|
||||
@@ -303,12 +309,12 @@ class TFT5Block(tf.keras.layers.Layer):
|
||||
self.is_decoder = config.is_decoder
|
||||
self.layer = []
|
||||
self.layer.append(
|
||||
TFT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._0")
|
||||
TFT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._0",)
|
||||
)
|
||||
if self.is_decoder:
|
||||
self.layer.append(
|
||||
TFT5LayerCrossAttention(
|
||||
config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._1"
|
||||
config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._1",
|
||||
)
|
||||
)
|
||||
self.layer.append(TFT5LayerFF(config, name="layer_._2"))
|
||||
@@ -402,7 +408,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
self.num_hidden_layers = config.num_layers
|
||||
|
||||
self.block = [
|
||||
TFT5Block(config, has_relative_attention_bias=bool(i == 0), name="block_._{}".format(i))
|
||||
TFT5Block(config, has_relative_attention_bias=bool(i == 0), name="block_._{}".format(i),)
|
||||
for i in range(config.num_layers)
|
||||
]
|
||||
self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
|
||||
@@ -469,7 +475,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
||||
if self.config.is_decoder:
|
||||
seq_ids = tf.range(seq_length)
|
||||
causal_mask = tf.less_equal(
|
||||
tf.tile(seq_ids[None, None, :], (batch_size, seq_length, 1)), seq_ids[None, :, None]
|
||||
tf.tile(seq_ids[None, None, :], (batch_size, seq_length, 1)), seq_ids[None, :, None],
|
||||
)
|
||||
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
@@ -748,7 +754,7 @@ class TFT5Model(TFT5PreTrainedModel):
|
||||
# Encode if needed (training, first prediction pass)
|
||||
if encoder_outputs is None:
|
||||
encoder_outputs = self.encoder(
|
||||
input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask
|
||||
input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask,
|
||||
)
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
@@ -852,7 +858,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
|
||||
if encoder_outputs is None:
|
||||
# Convert encoder inputs in embeddings if needed
|
||||
encoder_outputs = self.encoder(
|
||||
input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask
|
||||
input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, head_mask=head_mask,
|
||||
)
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
@@ -1112,12 +1112,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
|
||||
|
||||
# next batch beam content
|
||||
# list of (batch_size * num_beams) tuple(next hypothesis score, next token, current position in the batch)
|
||||
next_batch_beam = []
|
||||
|
||||
# for each sentence
|
||||
for batch_idx in range(batch_size):
|
||||
|
||||
# if we are done with this sentence
|
||||
if done[batch_idx]:
|
||||
assert (
|
||||
len(generated_hyps[batch_idx]) >= num_beams
|
||||
@@ -1135,14 +1135,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
|
||||
zip(next_tokens[batch_idx], next_scores[batch_idx])
|
||||
):
|
||||
|
||||
# get beam and token IDs
|
||||
beam_id = beam_token_id // vocab_size
|
||||
token_id = beam_token_id % vocab_size
|
||||
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
# add to generated hypotheses if end of sentence or last iteration
|
||||
if eos_token_id is not None and token_id.numpy() is eos_token_id:
|
||||
if (eos_token_id is not None) and (token_id.numpy() == eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
@@ -1158,9 +1157,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
if len(next_sent_beam) == num_beams:
|
||||
break
|
||||
|
||||
# if we are done with this sentence
|
||||
# Check if were done so that we can save a pad step if all(done)
|
||||
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
||||
tf.reduce_max(next_scores[batch_idx]).numpy()
|
||||
tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
|
||||
)
|
||||
|
||||
# update next beam content
|
||||
@@ -1178,6 +1177,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
|
||||
beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
|
||||
|
||||
print("Scores: {}-{}".format(cur_len, beam_scores.numpy()))
|
||||
|
||||
# re-order batch
|
||||
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
|
||||
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
|
||||
@@ -1185,6 +1186,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
if past is not None:
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
# extend attention_mask for new generated input if only decoder
|
||||
if self.config.is_encoder_decoder is False:
|
||||
attention_mask = tf.concat(
|
||||
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
|
||||
@@ -1244,16 +1246,26 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
||||
|
||||
# fill with hypothesis and eos_token_id if necessary
|
||||
for i, hypo in enumerate(best):
|
||||
padding = tf.ones((sent_max_len - shape_list(hypo)[0],), dtype=tf.int32) * pad_token_id
|
||||
decoded_hypo = tf.concat([hypo, padding], axis=0)
|
||||
assert sent_lengths[i] == shape_list(hypo)[0]
|
||||
# if sent_length is max_len do not pad
|
||||
if sent_lengths[i] == sent_max_len:
|
||||
decoded_slice = hypo
|
||||
else:
|
||||
# else pad to sent_max_len
|
||||
num_pad_tokens = sent_max_len - sent_lengths[i]
|
||||
padding = pad_token_id * tf.ones((num_pad_tokens,), dtype=tf.int32)
|
||||
decoded_slice = tf.concat([hypo, padding], axis=-1)
|
||||
|
||||
# finish sentence with EOS token
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded_slice = tf.where(
|
||||
tf.range(sent_max_len, dtype=tf.int32) == sent_lengths[i],
|
||||
eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
|
||||
decoded_slice,
|
||||
)
|
||||
# add to list
|
||||
decoded_list.append(decoded_slice)
|
||||
|
||||
if sent_lengths[i] < max_length:
|
||||
decoded_hypo = tf.where(
|
||||
tf.range(max_length) == sent_lengths[i],
|
||||
eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
|
||||
decoded_hypo,
|
||||
)
|
||||
decoded_list.append(decoded_hypo)
|
||||
decoded = tf.stack(decoded_list)
|
||||
else:
|
||||
# none of the hypotheses have an eos_token
|
||||
|
||||
@@ -960,6 +960,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
.to(input_ids.device)
|
||||
)
|
||||
encoder_outputs = (encoder_outputs[0].index_select(batch_idx, expanded_idx), *encoder_outputs[1:])
|
||||
|
||||
else:
|
||||
encoder_outputs = None
|
||||
cur_len = input_ids.shape[-1]
|
||||
@@ -1284,14 +1285,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
|
||||
zip(next_tokens[batch_idx], next_scores[batch_idx])
|
||||
):
|
||||
# get beam and word IDs
|
||||
# get beam and token IDs
|
||||
beam_id = beam_token_id // vocab_size
|
||||
token_id = beam_token_id % vocab_size
|
||||
|
||||
effective_beam_id = batch_idx * num_beams + beam_id
|
||||
|
||||
# add to generated hypotheses if end of sentence
|
||||
if (eos_token_id is not None) and (token_id.item() is eos_token_id):
|
||||
# add to generated hypotheses if end of sentence or last iteration
|
||||
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
|
||||
# if beam_token does not belong to top num_beams tokens, it should not be added
|
||||
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
||||
if is_beam_token_worse_than_top_num_beams:
|
||||
@@ -1300,7 +1300,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
input_ids[effective_beam_id].clone(), beam_token_score.item(),
|
||||
)
|
||||
else:
|
||||
# add next predicted word if it is not eos_token
|
||||
# add next predicted token if it is not eos_token
|
||||
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
|
||||
|
||||
# the beam for next step is full
|
||||
@@ -1330,7 +1330,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
# re-order batch
|
||||
input_ids = input_ids[beam_idx, :]
|
||||
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
|
||||
|
||||
# re-order internal states
|
||||
if past is not None:
|
||||
past = self._reorder_cache(past, beam_idx)
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user