From e0ac72b7bd2cbed43164b69308c2cf546e1a49a5 Mon Sep 17 00:00:00 2001 From: Jaesun Park Date: Mon, 28 Mar 2022 21:06:48 +0900 Subject: [PATCH] Fix PerceiverMLP and test (#16405) Co-authored-by: Jaesun Park --- src/transformers/models/perceiver/modeling_perceiver.py | 2 +- tests/perceiver/test_modeling_perceiver.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index e79a7d5999..f323b9091a 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -420,7 +420,7 @@ class PerceiverMLP(nn.Module): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act - self.dense2 = nn.Linear(input_size, input_size) + self.dense2 = nn.Linear(widening_factor * input_size, input_size) def forward(self, hidden_states): hidden_states = self.dense1(hidden_states) diff --git a/tests/perceiver/test_modeling_perceiver.py b/tests/perceiver/test_modeling_perceiver.py index 5fd75ab649..a394b00852 100644 --- a/tests/perceiver/test_modeling_perceiver.py +++ b/tests/perceiver/test_modeling_perceiver.py @@ -82,6 +82,8 @@ class PerceiverModelTester: num_self_attends_per_block=2, num_self_attention_heads=1, num_cross_attention_heads=1, + self_attention_widening_factor=4, + cross_attention_widening_factor=4, is_training=True, use_input_mask=True, use_labels=True, @@ -109,6 +111,8 @@ class PerceiverModelTester: self.num_self_attends_per_block = num_self_attends_per_block self.num_self_attention_heads = num_self_attention_heads self.num_cross_attention_heads = num_cross_attention_heads + self.self_attention_widening_factor = self_attention_widening_factor + self.cross_attention_widening_factor = cross_attention_widening_factor self.is_training = is_training self.use_input_mask = use_input_mask self.use_labels = use_labels @@ -174,10 +178,14 @@ class PerceiverModelTester: return PerceiverConfig( num_latents=self.num_latents, d_latents=self.d_latents, + qk_channels=self.d_latents, + v_channels=self.d_latents, num_blocks=self.num_blocks, num_self_attends_per_block=self.num_self_attends_per_block, num_self_attention_heads=self.num_self_attention_heads, num_cross_attention_heads=self.num_cross_attention_heads, + self_attention_widening_factor=self.self_attention_widening_factor, + cross_attention_widening_factor=self.cross_attention_widening_factor, vocab_size=self.vocab_size, hidden_act=self.hidden_act, attention_probs_dropout_prob=self.attention_probs_dropout_prob,