* fix reshaping Fixes #21523 * add test * styling * last fixes * Update src/transformers/models/convbert/modeling_convbert.py * code quallity
This commit is contained in:
@@ -459,6 +459,11 @@ class ConvBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
result = model(inputs_embeds=inputs_embeds)
|
||||
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
|
||||
|
||||
def test_reducing_attention_heads(self):
|
||||
config, *inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
config.head_ratio = 4
|
||||
self.model_tester.create_and_check_for_masked_lm(config, *inputs_dict)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ConvBertModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user