From b599b192896b0016aab192394d4d0ce8f8e86672 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 1 Mar 2023 11:11:04 +0100 Subject: [PATCH] [ConvBert] Fix #21523 (#21849) * fix reshaping Fixes #21523 * add test * styling * last fixes * Update src/transformers/models/convbert/modeling_convbert.py * code quallity --- src/transformers/models/convbert/modeling_convbert.py | 7 +++++-- tests/models/convbert/test_modeling_convbert.py | 5 +++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 655ea55eeb..59f81d3145 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -316,7 +316,7 @@ class ConvBertSelfAttention(nn.Module): if config.hidden_size % self.num_attention_heads != 0: raise ValueError("hidden_size should be divisible by num_attention_heads") - self.attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = (config.hidden_size // self.num_attention_heads) // 2 self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) @@ -413,7 +413,10 @@ class ConvBertSelfAttention(nn.Module): conv_out = torch.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size]) context_layer = torch.cat([context_layer, conv_out], 2) - new_context_layer_shape = context_layer.size()[:-2] + (self.head_ratio * self.all_head_size,) + # conv and context + new_context_layer_shape = context_layer.size()[:-2] + ( + self.num_attention_heads * self.attention_head_size * 2, + ) context_layer = context_layer.view(*new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) diff --git a/tests/models/convbert/test_modeling_convbert.py b/tests/models/convbert/test_modeling_convbert.py index 98cd937763..dc1550acc2 100644 --- a/tests/models/convbert/test_modeling_convbert.py +++ b/tests/models/convbert/test_modeling_convbert.py @@ -459,6 +459,11 @@ class ConvBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase result = model(inputs_embeds=inputs_embeds) self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size)) + def test_reducing_attention_heads(self): + config, *inputs_dict = self.model_tester.prepare_config_and_inputs() + config.head_ratio = 4 + self.model_tester.create_and_check_for_masked_lm(config, *inputs_dict) + @require_torch class ConvBertModelIntegrationTest(unittest.TestCase):