[No merge] TF integration testing (#7621)
* stash * TF Integration testing for ELECTRA, BERT, Longformer * Trigger slow tests * Apply suggestions from code review
This commit is contained in:
@@ -248,3 +248,19 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_name in ["google/electra-small-discriminator"]:
|
||||
model = TFElectraModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class TFElectraModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = TFElectraForPreTraining.from_pretrained("lysandre/tiny-electra-random")
|
||||
input_ids = tf.constant([[0, 1, 2, 3, 4, 5]])
|
||||
output = model(input_ids)[0]
|
||||
|
||||
expected_shape = [1, 6]
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
print(output[:, :3])
|
||||
|
||||
expected_slice = tf.constant([[-0.24651965, 0.8835437, 1.823782]])
|
||||
tf.debugging.assert_near(output[:, :3], expected_slice, atol=1e-4)
|
||||
|
||||
Reference in New Issue
Block a user