[All models] Extend config.output_attentions with output_attentions function arguments (#4538)

* DOC: Replace instances of ``config.output_attentions`` with function argument ``output_attentions``

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attention`

* Fix further regressions in tests relating to `output_attentions`

Ensure proper propagation of `output_attentions` as a function parameter
to all model subclasses

* Fix more regressions in `test_output_attentions`

* Fix issues with BertEncoder

* Rename related variables to `output_attentions`

* fix pytorch tests

* fix bert and gpt2 tf

* Fix most TF tests for `test_output_attentions`

* Fix linter errors and more TF tests

* fix conflicts

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attention`

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* fix pytorch tests

* fix conflicts

* fix conflicts

* Fix linter errors and more TF tests

* fix tf tests

* make style

* fix isort

* improve output_attentions

* improve tensorflow

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Bharat Raghunathan
2020-06-10 03:09:06 +05:30
committed by GitHub
parent f90bc44d9a
commit 6e603cb789
38 changed files with 1108 additions and 549 deletions

View File

@@ -130,7 +130,7 @@ class ModelTesterMixin:
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
model = model_class(config)
model.to(torch_device)
@@ -138,7 +138,18 @@ class ModelTesterMixin:
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**inputs_dict)
attentions = outputs[-1]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
@@ -172,7 +183,7 @@ class ModelTesterMixin:
)
# Check attention is always last and order is fine
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
model = model_class(config)
model.to(torch_device)
@@ -180,7 +191,6 @@ class ModelTesterMixin:
with torch.no_grad():
outputs = model(**inputs_dict)
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
self_attentions = outputs[-1]
@@ -203,7 +213,6 @@ class ModelTesterMixin:
def test_torchscript_output_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_attentions = True
self._create_and_check_torchscript(config, inputs_dict)
@@ -270,7 +279,7 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
global_rng.seed()
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
for model_class in self.all_model_classes:
@@ -326,7 +335,7 @@ class ModelTesterMixin:
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
model = model_class(config=config)
model.to(torch_device)
@@ -355,7 +364,7 @@ class ModelTesterMixin:
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
model = model_class(config=config)
model.to(torch_device)
@@ -388,7 +397,7 @@ class ModelTesterMixin:
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
heads_to_prune = {
@@ -419,7 +428,7 @@ class ModelTesterMixin:
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
heads_to_prune = {0: [0], 1: [1, 2]}
@@ -471,14 +480,12 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config.output_hidden_states = True
config.output_attentions = False
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**inputs_dict)
hidden_states = outputs[-1]
self.assertEqual(model.config.output_attentions, False)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
@@ -838,7 +845,6 @@ class ModelUtilsTest(unittest.TestCase):
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)

View File

@@ -296,7 +296,7 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
test_headmasking = False # head masking is not supported
test_torchscript = False
all_model_classes = (LongformerForMaskedLM, LongformerModel) if is_torch_available() else ()
all_model_classes = (LongformerModel, LongformerForMaskedLM,) if is_torch_available() else ()
def setUp(self):
self.model_tester = LongformerModelTester(self)

View File

@@ -314,12 +314,11 @@ class TFModelTesterMixin:
)
for model_class in self.all_model_classes:
config.output_attentions = True
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
model = model_class(config)
outputs = model(inputs_dict)
attentions = [t.numpy() for t in outputs[-1]]
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
@@ -331,7 +330,6 @@ class TFModelTesterMixin:
if self.is_encoder_decoder:
self.assertEqual(out_len % 2, 0)
decoder_attentions = outputs[(out_len // 2) - 1]
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
@@ -339,13 +337,25 @@ class TFModelTesterMixin:
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# Check attention is always last and order is fine
# Check that output attentions can also be changed via the config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(inputs_dict)
attentions = [t.numpy() for t in outputs[-1]]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
model = model_class(config)
outputs = model(inputs_dict)
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
attentions = [t.numpy() for t in outputs[-1]]
@@ -360,11 +370,9 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
config.output_hidden_states = True
config.output_attentions = False
model = model_class(config)
outputs = model(inputs_dict)
hidden_states = [t.numpy() for t in outputs[-1]]
self.assertEqual(model.config.output_attentions, False)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
self.assertListEqual(

View File

@@ -238,7 +238,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
model.to(torch_device)
model.eval()
_, _, attentions = model(input_ids_1, target_mapping=target_mapping)
_, _, attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True)
self.parent.assertEqual(len(attentions), config.n_layer)
self.parent.assertIsInstance(attentions[0], tuple)
@@ -483,7 +483,6 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
def test_xlnet_base_model_with_att_output(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config_and_inputs[0].output_attentions = True
self.model_tester.create_and_check_xlnet_base_model_with_att_output(*config_and_inputs)
def test_xlnet_lm_head(self):