Tf model outputs (#6247)

* TF outputs and test on BERT

* Albert to DistilBert

* All remaining TF models except T5

* Documentation

* One file forgotten

* TF outputs and test on BERT

* Albert to DistilBert

* All remaining TF models except T5

* Documentation

* One file forgotten

* Add new models and fix issues

* Quality improvements

* Add T5

* A bit of cleanup

* Fix for slow tests

* Style
This commit is contained in:
Sylvain Gugger
2020-08-05 11:34:39 -04:00
committed by GitHub
parent bd0eab351a
commit c67d1a0259
51 changed files with 3253 additions and 2430 deletions

View File

@@ -78,6 +78,7 @@ class TFT5ModelTester:
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.pad_token_id,
return_dict=True,
)
return (config, input_ids, input_mask, token_labels)
@@ -89,22 +90,14 @@ class TFT5ModelTester:
"decoder_input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
decoder_output, decoder_past, encoder_output = model(inputs)
result = model(inputs)
decoder_output, decoder_past, encoder_output = model(
input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids
)
result = {
"encoder_output": encoder_output.numpy(),
"decoder_past": decoder_past,
"decoder_output": decoder_output.numpy(),
}
self.parent.assertListEqual(
list(result["encoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertListEqual(
list(result["decoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids)
decoder_output = result["last_hidden_state"]
decoder_past = result["decoder_past_key_values"]
encoder_output = result["encoder_last_hidden_state"]
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertEqual(len(decoder_past), 2)
# decoder_past[0] should correspond to encoder output
self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
@@ -121,14 +114,9 @@ class TFT5ModelTester:
"decoder_attention_mask": input_mask,
}
prediction_scores, _, _ = model(inputs_dict)
result = model(inputs_dict)
result = {
"prediction_scores": prediction_scores.numpy(),
}
self.parent.assertListEqual(
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
model = TFT5Model(config=config).get_decoder()