[T5, Tests] Add extensive hard-coded integration tests and make sure PT and TF give equal results (#3550)

* add some t5 integration tests

* finish summarization and translation integration tests for T5 - results look good

* add tf test

* fix == vs is bug

* fix tf beam search error and make tf t5 tests pass
This commit is contained in:
Patrick von Platen
2020-04-01 18:01:33 +02:00
committed by GitHub
parent 8538ce9044
commit b815edf69f
5 changed files with 326 additions and 36 deletions

View File

@@ -1112,12 +1112,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
# next batch beam content
# list of (batch_size * num_beams) tuple(next hypothesis score, next token, current position in the batch)
next_batch_beam = []
# for each sentence
for batch_idx in range(batch_size):
# if we are done with this sentence
if done[batch_idx]:
assert (
len(generated_hyps[batch_idx]) >= num_beams
@@ -1135,14 +1135,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx])
):
# get beam and token IDs
beam_id = beam_token_id // vocab_size
token_id = beam_token_id % vocab_size
effective_beam_id = batch_idx * num_beams + beam_id
# add to generated hypotheses if end of sentence or last iteration
if eos_token_id is not None and token_id.numpy() is eos_token_id:
if (eos_token_id is not None) and (token_id.numpy() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
if is_beam_token_worse_than_top_num_beams:
@@ -1158,9 +1157,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if len(next_sent_beam) == num_beams:
break
# if we are done with this sentence
# Check if we are done so that we can save a pad step if all(done)
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
tf.reduce_max(next_scores[batch_idx]).numpy()
tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
)
# update next beam content
@@ -1178,6 +1177,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
print("Scores: {}-{}".format(cur_len, beam_scores.numpy()))
# re-order batch
input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
@@ -1185,6 +1186,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
if past is not None:
past = self._reorder_cache(past, beam_idx)
# extend attention_mask for new generated input if only decoder
if self.config.is_encoder_decoder is False:
attention_mask = tf.concat(
[attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
@@ -1244,16 +1246,26 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# fill with hypothesis and eos_token_id if necessary
for i, hypo in enumerate(best):
padding = tf.ones((sent_max_len - shape_list(hypo)[0],), dtype=tf.int32) * pad_token_id
decoded_hypo = tf.concat([hypo, padding], axis=0)
assert sent_lengths[i] == shape_list(hypo)[0]
# if sent_length is max_len do not pad
if sent_lengths[i] == sent_max_len:
decoded_slice = hypo
else:
# else pad to sent_max_len
num_pad_tokens = sent_max_len - sent_lengths[i]
padding = pad_token_id * tf.ones((num_pad_tokens,), dtype=tf.int32)
decoded_slice = tf.concat([hypo, padding], axis=-1)
# finish sentence with EOS token
if sent_lengths[i] < max_length:
decoded_slice = tf.where(
tf.range(sent_max_len, dtype=tf.int32) == sent_lengths[i],
eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
decoded_slice,
)
# add to list
decoded_list.append(decoded_slice)
if sent_lengths[i] < max_length:
decoded_hypo = tf.where(
tf.range(max_length) == sent_lengths[i],
eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
decoded_hypo,
)
decoded_list.append(decoded_hypo)
decoded = tf.stack(decoded_list)
else:
# none of the hypotheses have an eos_token