Fix the issue of using only inputs_embeds in convbert model (#21398)
* Fix the input embeds issue with tests * Fix black and isort issue * Clean up tests * Add slow tag to the test introduced * Incorporate PR feedbacks
This commit is contained in:
@@ -818,12 +818,12 @@ class ConvBertModel(ConvBertPreTrainedModel):
|
|||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
batch_size, seq_length = input_shape
|
|
||||||
elif inputs_embeds is not None:
|
elif inputs_embeds is not None:
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
|
|||||||
@@ -437,6 +437,17 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
|
loaded = torch.jit.load(os.path.join(tmp, "traced_model.pt"), map_location=torch_device)
|
||||||
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
|
loaded(inputs_dict["input_ids"].to(torch_device), inputs_dict["attention_mask"].to(torch_device))
|
||||||
|
|
||||||
|
def test_model_for_input_embeds(self):
|
||||||
|
batch_size = 2
|
||||||
|
seq_length = 10
|
||||||
|
inputs_embeds = torch.rand([batch_size, seq_length, 768])
|
||||||
|
config = self.model_tester.get_config()
|
||||||
|
model = ConvBertModel(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(inputs_embeds=inputs_embeds)
|
||||||
|
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ConvBertModelIntegrationTest(unittest.TestCase):
|
class ConvBertModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user