[ConvBert] Fix #21523 (#21849)

* fix reshaping
Fixes #21523

* add test

* styling

* last fixes

* Update src/transformers/models/convbert/modeling_convbert.py

* code quallity
This commit is contained in:
Arthur
2023-03-01 11:11:04 +01:00
committed by GitHub
parent 44e3e3fb49
commit b599b19289
2 changed files with 10 additions and 2 deletions

View File

@@ -459,6 +459,11 @@ class ConvBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
result = model(inputs_embeds=inputs_embeds)
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
def test_reducing_attention_heads(self):
config, *inputs_dict = self.model_tester.prepare_config_and_inputs()
config.head_ratio = 4
self.model_tester.create_and_check_for_masked_lm(config, *inputs_dict)
@require_torch
class ConvBertModelIntegrationTest(unittest.TestCase):