Output global_attentions in Longformer models (#7562)
* Output global_attentions in Longformer models * make style * small refactoring * fix tests * make fix-copies * add for tf as well * remove comments in test * make fix-copies * make style * add docs * make docstring pretty Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -504,6 +504,7 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", self.model_tester.seq_length)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
|
||||
@@ -515,9 +516,10 @@ class TFModelTesterMixin:
|
||||
inputs_dict["use_cache"] = False
|
||||
config.output_hidden_states = False
|
||||
model = model_class(config)
|
||||
model_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
outputs = model(model_inputs)
|
||||
attentions = [t.numpy() for t in outputs[-1]]
|
||||
outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = [
|
||||
t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
|
||||
]
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
@@ -528,7 +530,7 @@ class TFModelTesterMixin:
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.assertEqual(out_len % 2, 0)
|
||||
decoder_attentions = outputs[(out_len // 2) - 1]
|
||||
decoder_attentions = outputs.decoder_attentions
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
@@ -541,7 +543,9 @@ class TFModelTesterMixin:
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
outputs = model(self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = [t.numpy() for t in outputs[-1]]
|
||||
attentions = [
|
||||
t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
|
||||
]
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
@@ -557,7 +561,9 @@ class TFModelTesterMixin:
|
||||
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
|
||||
attentions = [t.numpy() for t in outputs[-1]]
|
||||
attentions = [
|
||||
t.numpy() for t in (outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions)
|
||||
]
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
|
||||
Reference in New Issue
Block a user