Model output test (#6155)
* Use return_dict=True in all tests * Formatting
This commit is contained in:
@@ -28,7 +28,7 @@ if is_torch_available():
|
||||
class XLMRobertaModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_xlm_roberta_base(self):
|
||||
model = XLMRobertaModel.from_pretrained("xlm-roberta-base")
|
||||
model = XLMRobertaModel.from_pretrained("xlm-roberta-base", return_dict=True)
|
||||
input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]])
|
||||
# The dog is cute and lives in the garden house
|
||||
|
||||
@@ -40,14 +40,14 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase):
|
||||
# xlmr.eval()
|
||||
# expected_output_values_last_dim = xlmr.extract_features(input_ids[0])[:, :, -1]
|
||||
|
||||
output = model(input_ids)[0].detach()
|
||||
output = model(input_ids)["last_hidden_state"].detach()
|
||||
self.assertEqual(output.shape, expected_output_shape)
|
||||
# compare the actual values for a slice of last dim
|
||||
self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_xlm_roberta_large(self):
|
||||
model = XLMRobertaModel.from_pretrained("xlm-roberta-large")
|
||||
model = XLMRobertaModel.from_pretrained("xlm-roberta-large", return_dict=True)
|
||||
input_ids = torch.tensor([[0, 581, 10269, 83, 99942, 136, 60742, 23, 70, 80583, 18276, 2]])
|
||||
# The dog is cute and lives in the garden house
|
||||
|
||||
@@ -59,7 +59,7 @@ class XLMRobertaModelIntegrationTest(unittest.TestCase):
|
||||
# xlmr.eval()
|
||||
# expected_output_values_last_dim = xlmr.extract_features(input_ids[0])[:, :, -1]
|
||||
|
||||
output = model(input_ids)[0].detach()
|
||||
output = model(input_ids)["last_hidden_state"].detach()
|
||||
self.assertEqual(output.shape, expected_output_shape)
|
||||
# compare the actual values for a slice of last dim
|
||||
self.assertTrue(torch.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user