From 77db257e2a67d4b043cf03bf390947fcd71a9f53 Mon Sep 17 00:00:00 2001
From: raghavanone <115454562+raghavanone@users.noreply.github.com>
Date: Wed, 1 Feb 2023 20:17:25 +0530
Subject: [PATCH] Fix the issue of using only inputs_embeds in convbert model (#21398)

* Fix the input embeds issue with tests

* Fix black and isort issue

* Clean up tests

* Add slow tag to the test introduced

* Incorporate PR feedbacks
---
 src/transformers/models/convbert/modeling_convbert.py |  2 +-
 tests/models/convbert/test_modeling_convbert.py       | 11 +++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py
index 6a3e81e25e..655ea55eeb 100755
--- a/src/transformers/models/convbert/modeling_convbert.py
+++ b/src/transformers/models/convbert/modeling_convbert.py
@@ -818,12 +818,12 @@ class ConvBertModel(ConvBertPreTrainedModel):
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
             input_shape = input_ids.size()
-            batch_size, seq_length = input_shape
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        batch_size, seq_length = input_shape
         device = input_ids.device if input_ids is not None else inputs_embeds.device
 
         if attention_mask is None:
diff --git a/tests/models/convbert/test_modeling_convbert.py b/tests/models/convbert/test_modeling_convbert.py
index f2b82aaadf..49f363ff7b 100644
--- a/tests/models/convbert/test_modeling_convbert.py
+++ b/tests/models/convbert/test_modeling_convbert.py
@@ -437,6 +437,17 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
                 loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
                 loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
 
+    def test_model_for_input_embeds(self):
+        batch_size = 2
+        seq_length = 10
+        inputs_embeds = torch.rand([batch_size, seq_length, 768])
+        config = self.model_tester.get_config()
+        model = ConvBertModel(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(inputs_embeds=inputs_embeds)
+        self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
+
 
 @require_torch
 class ConvBertModelIntegrationTest(unittest.TestCase):