Fix PerceiverMLP and test (#16405)
Co-authored-by: Jaesun Park <jaesun.park1@navercorp.com>
This commit is contained in:
@@ -420,7 +420,7 @@ class PerceiverMLP(nn.Module):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
self.dense2 = nn.Linear(input_size, input_size)
|
||||
self.dense2 = nn.Linear(widening_factor * input_size, input_size)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense1(hidden_states)
|
||||
|
||||
@@ -82,6 +82,8 @@ class PerceiverModelTester:
|
||||
num_self_attends_per_block=2,
|
||||
num_self_attention_heads=1,
|
||||
num_cross_attention_heads=1,
|
||||
self_attention_widening_factor=4,
|
||||
cross_attention_widening_factor=4,
|
||||
is_training=True,
|
||||
use_input_mask=True,
|
||||
use_labels=True,
|
||||
@@ -109,6 +111,8 @@ class PerceiverModelTester:
|
||||
self.num_self_attends_per_block = num_self_attends_per_block
|
||||
self.num_self_attention_heads = num_self_attention_heads
|
||||
self.num_cross_attention_heads = num_cross_attention_heads
|
||||
self.self_attention_widening_factor = self_attention_widening_factor
|
||||
self.cross_attention_widening_factor = cross_attention_widening_factor
|
||||
self.is_training = is_training
|
||||
self.use_input_mask = use_input_mask
|
||||
self.use_labels = use_labels
|
||||
@@ -174,10 +178,14 @@ class PerceiverModelTester:
|
||||
return PerceiverConfig(
|
||||
num_latents=self.num_latents,
|
||||
d_latents=self.d_latents,
|
||||
qk_channels=self.d_latents,
|
||||
v_channels=self.d_latents,
|
||||
num_blocks=self.num_blocks,
|
||||
num_self_attends_per_block=self.num_self_attends_per_block,
|
||||
num_self_attention_heads=self.num_self_attention_heads,
|
||||
num_cross_attention_heads=self.num_cross_attention_heads,
|
||||
self_attention_widening_factor=self.self_attention_widening_factor,
|
||||
cross_attention_widening_factor=self.cross_attention_widening_factor,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_act=self.hidden_act,
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
|
||||
Reference in New Issue
Block a user