Fix PerceiverMLP and test (#16405)
Co-authored-by: Jaesun Park <jaesun.park1@navercorp.com>
This commit is contained in:
@@ -420,7 +420,7 @@ class PerceiverMLP(nn.Module):
|
|||||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
else:
|
else:
|
||||||
self.intermediate_act_fn = config.hidden_act
|
self.intermediate_act_fn = config.hidden_act
|
||||||
self.dense2 = nn.Linear(input_size, input_size)
|
self.dense2 = nn.Linear(widening_factor * input_size, input_size)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
hidden_states = self.dense1(hidden_states)
|
hidden_states = self.dense1(hidden_states)
|
||||||
|
|||||||
@@ -82,6 +82,8 @@ class PerceiverModelTester:
|
|||||||
num_self_attends_per_block=2,
|
num_self_attends_per_block=2,
|
||||||
num_self_attention_heads=1,
|
num_self_attention_heads=1,
|
||||||
num_cross_attention_heads=1,
|
num_cross_attention_heads=1,
|
||||||
|
self_attention_widening_factor=4,
|
||||||
|
cross_attention_widening_factor=4,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
use_input_mask=True,
|
use_input_mask=True,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
@@ -109,6 +111,8 @@ class PerceiverModelTester:
|
|||||||
self.num_self_attends_per_block = num_self_attends_per_block
|
self.num_self_attends_per_block = num_self_attends_per_block
|
||||||
self.num_self_attention_heads = num_self_attention_heads
|
self.num_self_attention_heads = num_self_attention_heads
|
||||||
self.num_cross_attention_heads = num_cross_attention_heads
|
self.num_cross_attention_heads = num_cross_attention_heads
|
||||||
|
self.self_attention_widening_factor = self_attention_widening_factor
|
||||||
|
self.cross_attention_widening_factor = cross_attention_widening_factor
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
self.use_input_mask = use_input_mask
|
self.use_input_mask = use_input_mask
|
||||||
self.use_labels = use_labels
|
self.use_labels = use_labels
|
||||||
@@ -174,10 +178,14 @@ class PerceiverModelTester:
|
|||||||
return PerceiverConfig(
|
return PerceiverConfig(
|
||||||
num_latents=self.num_latents,
|
num_latents=self.num_latents,
|
||||||
d_latents=self.d_latents,
|
d_latents=self.d_latents,
|
||||||
|
qk_channels=self.d_latents,
|
||||||
|
v_channels=self.d_latents,
|
||||||
num_blocks=self.num_blocks,
|
num_blocks=self.num_blocks,
|
||||||
num_self_attends_per_block=self.num_self_attends_per_block,
|
num_self_attends_per_block=self.num_self_attends_per_block,
|
||||||
num_self_attention_heads=self.num_self_attention_heads,
|
num_self_attention_heads=self.num_self_attention_heads,
|
||||||
num_cross_attention_heads=self.num_cross_attention_heads,
|
num_cross_attention_heads=self.num_cross_attention_heads,
|
||||||
|
self_attention_widening_factor=self.self_attention_widening_factor,
|
||||||
|
cross_attention_widening_factor=self.cross_attention_widening_factor,
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
hidden_act=self.hidden_act,
|
hidden_act=self.hidden_act,
|
||||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||||
|
|||||||
Reference in New Issue
Block a user