From 0e1718afb61c1acb786d120af4c32341b454eed5 Mon Sep 17 00:00:00 2001 From: sadakmed Date: Mon, 5 Jul 2021 11:21:25 +0200 Subject: [PATCH] create LxmertModelIntegrationTest Pytorch (#9989) * create LxmertModelIntegrationTest * implementation using numpy seeding to fix inputs params. * fix code quality * isort check --- tests/test_modeling_lxmert.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_modeling_lxmert.py b/tests/test_modeling_lxmert.py index 451db8089a..1febee25d5 100644 --- a/tests/test_modeling_lxmert.py +++ b/tests/test_modeling_lxmert.py @@ -17,6 +17,8 @@ import copy import unittest +import numpy as np + from transformers import is_torch_available from transformers.models.auto import get_values from transformers.testing_utils import require_torch, slow, torch_device @@ -727,3 +729,24 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase): self.assertIsNotNone(attentions_vision.grad) self.assertIsNotNone(hidden_states_vision.grad) self.assertIsNotNone(attentions_vision.grad) + + +@require_torch +class LxmertModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_no_head_absolute_embedding(self): + model = LxmertModel.from_pretrained(LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST[0]) + input_ids = torch.tensor([[101, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 102]]) + num_visual_features = 10 + _, visual_feats = np.random.seed(0), np.random.rand(1, num_visual_features, LxmertModel.config.visual_feat_dim) + _, visual_pos = np.random.seed(0), np.random.rand(1, num_visual_features, 4) + visual_feats = torch.as_tensor(visual_feats, dtype=torch.float32) + visual_pos = torch.as_tensor(visual_pos, dtype=torch.float32) + output = model(input_ids, visual_feats=visual_feats, visual_pos=visual_pos)[0] + expected_shape = torch.Size([1, 11, 768]) + self.assertEqual(expected_shape, output.shape) + expected_slice = torch.tensor( + [[[0.2417, -0.9807, 0.1480], [1.2541, -0.8320, 0.5112], [1.4070, -1.1052, 0.6990]]] + ) + + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))