Tf model outputs (#6247)
* TF outputs and test on BERT
* Albert to DistilBert
* All remaining TF models except T5
* Documentation
* One file forgotten
* TF outputs and test on BERT
* Albert to DistilBert
* All remaining TF models except T5
* Documentation
* One file forgotten
* Add new models and fix issues
* Quality improvements
* Add T5
* A bit of cleanup
* Fix for slow tests
* Style
This commit is contained in:
@@ -78,6 +78,7 @@ class TFT5ModelTester:
|
||||
bos_token_id=self.pad_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
decoder_start_token_id=self.pad_token_id,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
return (config, input_ids, input_mask, token_labels)
|
||||
@@ -89,22 +90,14 @@ class TFT5ModelTester:
|
||||
"decoder_input_ids": input_ids,
|
||||
"decoder_attention_mask": input_mask,
|
||||
}
|
||||
decoder_output, decoder_past, encoder_output = model(inputs)
|
||||
result = model(inputs)
|
||||
|
||||
decoder_output, decoder_past, encoder_output = model(
|
||||
input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids
|
||||
)
|
||||
result = {
|
||||
"encoder_output": encoder_output.numpy(),
|
||||
"decoder_past": decoder_past,
|
||||
"decoder_output": decoder_output.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["encoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(result["decoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
)
|
||||
result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids)
|
||||
decoder_output = result["last_hidden_state"]
|
||||
decoder_past = result["decoder_past_key_values"]
|
||||
encoder_output = result["encoder_last_hidden_state"]
|
||||
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertEqual(len(decoder_past), 2)
|
||||
# decoder_past[0] should correspond to encoder output
|
||||
self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
|
||||
@@ -121,14 +114,9 @@ class TFT5ModelTester:
|
||||
"decoder_attention_mask": input_mask,
|
||||
}
|
||||
|
||||
prediction_scores, _, _ = model(inputs_dict)
|
||||
result = model(inputs_dict)
|
||||
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
|
||||
model = TFT5Model(config=config).get_decoder()
|
||||
|
||||
Reference in New Issue
Block a user