Tf model outputs (#6247)
* TF outputs and test on BERT * Albert to DistilBert * All remaining TF models except T5 * Documentation * One file forgotten * TF outputs and test on BERT * Albert to DistilBert * All remaining TF models except T5 * Documentation * One file forgotten * Add new models and fix issues * Quality improvements * Add T5 * A bit of cleanup * Fix for slow tests * Style
This commit is contained in:
@@ -146,7 +146,8 @@ class TFModelTesterMixin:
|
||||
tf.saved_model.save(model, tmpdirname)
|
||||
model = tf.keras.models.load_model(tmpdirname)
|
||||
outputs = model(inputs_dict)
|
||||
hidden_states = [t.numpy() for t in outputs[-1]]
|
||||
output = outputs[list(outputs.keys())[-1]] if isinstance(outputs, dict) else outputs[-1]
|
||||
hidden_states = [t.numpy() for t in output]
|
||||
self.assertEqual(len(outputs), num_out)
|
||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||
self.assertListEqual(
|
||||
@@ -177,7 +178,8 @@ class TFModelTesterMixin:
|
||||
tf.saved_model.save(model, tmpdirname)
|
||||
model = tf.keras.models.load_model(tmpdirname)
|
||||
outputs = model(inputs_dict)
|
||||
attentions = [t.numpy() for t in outputs[-1]]
|
||||
output = outputs[list(outputs.keys())[-1]] if isinstance(outputs, dict) else outputs[-1]
|
||||
attentions = [t.numpy() for t in output]
|
||||
self.assertEqual(len(outputs), num_out)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
@@ -238,6 +240,8 @@ class TFModelTesterMixin:
|
||||
# Make sure we don't have nans
|
||||
if isinstance(after_outputs, tf.Tensor):
|
||||
out_1 = after_outputs.numpy()
|
||||
elif isinstance(after_outputs, dict):
|
||||
out_1 = after_outputs[list(after_outputs.keys())[0]]
|
||||
else:
|
||||
out_1 = after_outputs[0].numpy()
|
||||
out_2 = outputs[0].numpy()
|
||||
|
||||
Reference in New Issue
Block a user