Fix the issue of using only inputs_embeds in convbert model (#21398)

* Fix the input embeds issue with tests

* Fix black and isort issue

* Clean up tests

* Add slow tag to the test introduced

* Incorporate PR feedbacks
This commit is contained in:
raghavanone
2023-02-01 20:17:25 +05:30
committed by GitHub
parent 65b5035a1d
commit 77db257e2a
2 changed files with 12 additions and 1 deletions

View File

@@ -437,6 +437,17 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
def test_model_for_input_embeds(self):
batch_size = 2
seq_length = 10
inputs_embeds = torch.rand([batch_size, seq_length, 768])
config = self.model_tester.get_config()
model = ConvBertModel(config=config)
model.to(torch_device)
model.eval()
result = model(inputs_embeds=inputs_embeds)
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
@require_torch
class ConvBertModelIntegrationTest(unittest.TestCase):