* fix reshaping Fixes #21523 * add test * styling * last fixes * Update src/transformers/models/convbert/modeling_convbert.py * code quallity
This commit is contained in:
@@ -316,7 +316,7 @@ class ConvBertSelfAttention(nn.Module):
|
|||||||
if config.hidden_size % self.num_attention_heads != 0:
|
if config.hidden_size % self.num_attention_heads != 0:
|
||||||
raise ValueError("hidden_size should be divisible by num_attention_heads")
|
raise ValueError("hidden_size should be divisible by num_attention_heads")
|
||||||
|
|
||||||
self.attention_head_size = config.hidden_size // config.num_attention_heads
|
self.attention_head_size = (config.hidden_size // self.num_attention_heads) // 2
|
||||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
|
|
||||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||||
@@ -413,7 +413,10 @@ class ConvBertSelfAttention(nn.Module):
|
|||||||
conv_out = torch.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size])
|
conv_out = torch.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size])
|
||||||
context_layer = torch.cat([context_layer, conv_out], 2)
|
context_layer = torch.cat([context_layer, conv_out], 2)
|
||||||
|
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.head_ratio * self.all_head_size,)
|
# conv and context
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (
|
||||||
|
self.num_attention_heads * self.attention_head_size * 2,
|
||||||
|
)
|
||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|||||||
@@ -459,6 +459,11 @@ class ConvBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||||||
result = model(inputs_embeds=inputs_embeds)
|
result = model(inputs_embeds=inputs_embeds)
|
||||||
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
|
self.assertEqual(result.last_hidden_state.shape, (batch_size, seq_length, config.hidden_size))
|
||||||
|
|
||||||
|
def test_reducing_attention_heads(self):
|
||||||
|
config, *inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||||
|
config.head_ratio = 4
|
||||||
|
self.model_tester.create_and_check_for_masked_lm(config, *inputs_dict)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ConvBertModelIntegrationTest(unittest.TestCase):
|
class ConvBertModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user