From 0a921b64595448d27af7ca66cebd153e8860ce95 Mon Sep 17 00:00:00 2001 From: Max Del Date: Fri, 27 Nov 2020 19:35:34 +0200 Subject: [PATCH] BART & FSMT: fix decoder not returning hidden states from the last layer (#8597) * Fix decoder not returning hidden states from the last layer * Resolve conflict * Change the way to gather hidden states * Add decoder hidden states test * Make pytest and black happy * Remove redundant line * remove new line Co-authored-by: Stas Bekman --- src/transformers/models/bart/modeling_bart.py | 6 ++++++ src/transformers/models/fsmt/modeling_fsmt.py | 6 ++++++ tests/test_modeling_common.py | 17 ++++++++++++++++- 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 684db1b87c..d1cbfe307c 100644 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -610,6 +610,12 @@ class BartDecoder(nn.Module): all_self_attns += (layer_self_attn,) all_cross_attentions += (layer_cross_attn,) + # add hidden states from the last decoder layer + if output_hidden_states: + x = x.transpose(0, 1) + all_hidden_states += (x,) + x = x.transpose(0, 1) + if self.layer_norm: # if config.add_final_layer_norm (mBART) x = self.layer_norm(x) diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 457c0a5ab9..682d6af006 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -692,6 +692,12 @@ class FSMTDecoder(nn.Module): all_self_attns += (layer_self_attn,) all_cross_attns += (layer_cross_attn,) + # add hidden states from the last decoder layer + if output_hidden_states: + x = x.transpose(0, 1) + all_hidden_states += (x,) + x = x.transpose(0, 1) + # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) x = x.transpose(0, 1) encoder_hidden_states = encoder_hidden_states.transpose(0, 1) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6a6eed12b9..977a4a1e8b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -659,12 +659,14 @@ class ModelTesterMixin: with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - hidden_states = outputs["hidden_states"] if "hidden_states" in outputs else outputs[-1] + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states expected_num_layers = getattr( self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 ) self.assertEqual(len(hidden_states), expected_num_layers) + if hasattr(self.model_tester, "encoder_seq_length"): seq_length = self.model_tester.encoder_seq_length if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1: @@ -677,6 +679,19 @@ class ModelTesterMixin: [seq_length, self.model_tester.hidden_size], ) + if config.is_encoder_decoder: + hidden_states = outputs.decoder_hidden_states + + self.assertIsInstance(hidden_states, (list, tuple)) + self.assertEqual(len(hidden_states), expected_num_layers) + seq_len = getattr(self.model_tester, "seq_length", None) + decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [decoder_seq_length, self.model_tester.hidden_size], + ) + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: