Fix PerceiverMLP and test (#16405)

Co-authored-by: Jaesun Park <jaesun.park1@navercorp.com>
This commit is contained in:
Jaesun Park
2022-03-28 21:06:48 +09:00
committed by GitHub
parent 473709fc76
commit e0ac72b7bd
2 changed files with 9 additions and 1 deletions

View File

@@ -420,7 +420,7 @@ class PerceiverMLP(nn.Module):
self.intermediate_act_fn = ACT2FN[config.hidden_act] self.intermediate_act_fn = ACT2FN[config.hidden_act]
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
self.dense2 = nn.Linear(input_size, input_size) self.dense2 = nn.Linear(widening_factor * input_size, input_size)
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states = self.dense1(hidden_states) hidden_states = self.dense1(hidden_states)

View File

@@ -82,6 +82,8 @@ class PerceiverModelTester:
num_self_attends_per_block=2, num_self_attends_per_block=2,
num_self_attention_heads=1, num_self_attention_heads=1,
num_cross_attention_heads=1, num_cross_attention_heads=1,
self_attention_widening_factor=4,
cross_attention_widening_factor=4,
is_training=True, is_training=True,
use_input_mask=True, use_input_mask=True,
use_labels=True, use_labels=True,
@@ -109,6 +111,8 @@ class PerceiverModelTester:
self.num_self_attends_per_block = num_self_attends_per_block self.num_self_attends_per_block = num_self_attends_per_block
self.num_self_attention_heads = num_self_attention_heads self.num_self_attention_heads = num_self_attention_heads
self.num_cross_attention_heads = num_cross_attention_heads self.num_cross_attention_heads = num_cross_attention_heads
self.self_attention_widening_factor = self_attention_widening_factor
self.cross_attention_widening_factor = cross_attention_widening_factor
self.is_training = is_training self.is_training = is_training
self.use_input_mask = use_input_mask self.use_input_mask = use_input_mask
self.use_labels = use_labels self.use_labels = use_labels
@@ -174,10 +178,14 @@ class PerceiverModelTester:
return PerceiverConfig( return PerceiverConfig(
num_latents=self.num_latents, num_latents=self.num_latents,
d_latents=self.d_latents, d_latents=self.d_latents,
qk_channels=self.d_latents,
v_channels=self.d_latents,
num_blocks=self.num_blocks, num_blocks=self.num_blocks,
num_self_attends_per_block=self.num_self_attends_per_block, num_self_attends_per_block=self.num_self_attends_per_block,
num_self_attention_heads=self.num_self_attention_heads, num_self_attention_heads=self.num_self_attention_heads,
num_cross_attention_heads=self.num_cross_attention_heads, num_cross_attention_heads=self.num_cross_attention_heads,
self_attention_widening_factor=self.self_attention_widening_factor,
cross_attention_widening_factor=self.cross_attention_widening_factor,
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
hidden_act=self.hidden_act, hidden_act=self.hidden_act,
attention_probs_dropout_prob=self.attention_probs_dropout_prob, attention_probs_dropout_prob=self.attention_probs_dropout_prob,