BART & FSMT: fix decoder not returning hidden states from the last layer (#8597)
* Fix decoder not returning hidden states from the last layer
* Resolve conflict
* Change the way to gather hidden states
* Add decoder hidden states test
* Make pytest and black happy
* Remove redundant line
* remove new line

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -610,6 +610,12 @@ class BartDecoder(nn.Module):
|
|||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
all_cross_attentions += (layer_cross_attn,)
|
all_cross_attentions += (layer_cross_attn,)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
all_hidden_states += (x,)
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
if self.layer_norm: # if config.add_final_layer_norm (mBART)
|
if self.layer_norm: # if config.add_final_layer_norm (mBART)
|
||||||
x = self.layer_norm(x)
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
|
|||||||
@@ -692,6 +692,12 @@ class FSMTDecoder(nn.Module):
|
|||||||
all_self_attns += (layer_self_attn,)
|
all_self_attns += (layer_self_attn,)
|
||||||
all_cross_attns += (layer_cross_attn,)
|
all_cross_attns += (layer_cross_attn,)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
all_hidden_states += (x,)
|
||||||
|
x = x.transpose(0, 1)
|
||||||
|
|
||||||
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||||||
x = x.transpose(0, 1)
|
x = x.transpose(0, 1)
|
||||||
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||||||
|
|||||||
@@ -659,12 +659,14 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
hidden_states = outputs["hidden_states"] if "hidden_states" in outputs else outputs[-1]
|
|
||||||
|
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||||
|
|
||||||
expected_num_layers = getattr(
|
expected_num_layers = getattr(
|
||||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||||
)
|
)
|
||||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||||
|
|
||||||
if hasattr(self.model_tester, "encoder_seq_length"):
|
if hasattr(self.model_tester, "encoder_seq_length"):
|
||||||
seq_length = self.model_tester.encoder_seq_length
|
seq_length = self.model_tester.encoder_seq_length
|
||||||
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
|
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
|
||||||
@@ -677,6 +679,19 @@ class ModelTesterMixin:
|
|||||||
[seq_length, self.model_tester.hidden_size],
|
[seq_length, self.model_tester.hidden_size],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.is_encoder_decoder:
|
||||||
|
hidden_states = outputs.decoder_hidden_states
|
||||||
|
|
||||||
|
self.assertIsInstance(hidden_states, (list, tuple))
|
||||||
|
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||||
|
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||||
|
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
list(hidden_states[0].shape[-2:]),
|
||||||
|
[decoder_seq_length, self.model_tester.hidden_size],
|
||||||
|
)
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
|||||||
Reference in New Issue
Block a user