From 2f17464266ef5fe8314f78de1320e16cdf29d909 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 11 Nov 2019 19:56:45 -0500 Subject: [PATCH] [common attributes] Slightly sharper test coverage --- transformers/tests/modeling_common_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/transformers/tests/modeling_common_test.py b/transformers/tests/modeling_common_test.py index 38b2ceafa4..777e62459b 100644 --- a/transformers/tests/modeling_common_test.py +++ b/transformers/tests/modeling_common_test.py @@ -468,9 +468,15 @@ class CommonTestCases: for model_class in self.all_model_classes: model = model_class(config) - model.get_input_embeddings() + self.assertIsInstance( + model.get_input_embeddings(), + torch.nn.Embedding + ) model.set_input_embeddings(torch.nn.Embedding(10, 10)) - model.get_output_embeddings() + x = model.get_output_embeddings() + self.assertTrue( + x is None or isinstance(x, torch.nn.Linear) + ) def test_tie_model_weights(self): if not self.test_torchscript: