[All models] Extend config.output_attentions with output_attentions function arguments (#4538)
* DOC: Replace instances of ``config.output_attentions`` with function argument ``output_attentions``
* DOC: Apply Black Formatting
* Fix errors where output_attentions was undefined
* Remove output_attentions in classes per review
* Fix regressions on tests having `output_attentions`
* Fix further regressions in tests relating to `output_attentions`
  Ensure proper propagation of `output_attentions` as a function parameter to all model subclasses
* Fix more regressions in `test_output_attentions`
* Fix issues with BertEncoder
* Rename related variables to `output_attentions`
* fix pytorch tests
* fix bert and gpt2 tf
* Fix most TF tests for `test_output_attentions`
* Fix linter errors and more TF tests
* fix conflicts
* DOC: Apply Black Formatting
* Fix errors where output_attentions was undefined
* Remove output_attentions in classes per review
* Fix regressions on tests having `output_attentions`
* fix conflicts
* fix conflicts
* fix conflicts
* fix conflicts
* fix pytorch tests
* fix conflicts
* fix conflicts
* Fix linter errors and more TF tests
* fix tf tests
* make style
* fix isort
* improve output_attentions
* improve tensorflow

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f90bc44d9a
commit
6e603cb789
@@ -130,7 +130,7 @@ class ModelTesterMixin:
|
||||
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config.output_attentions = True
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = False
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
@@ -138,7 +138,18 @@ class ModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict)
|
||||
attentions = outputs[-1]
|
||||
self.assertEqual(model.config.output_attentions, True)
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also works when it is set via the config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict)
|
||||
attentions = outputs[-1]
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
@@ -172,7 +183,7 @@ class ModelTesterMixin:
|
||||
)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
config.output_attentions = True
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = True
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
@@ -180,7 +191,6 @@ class ModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict)
|
||||
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
|
||||
self.assertEqual(model.config.output_attentions, True)
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
|
||||
self_attentions = outputs[-1]
|
||||
@@ -203,7 +213,6 @@ class ModelTesterMixin:
|
||||
|
||||
def test_torchscript_output_attentions(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
config.output_attentions = True
|
||||
self._create_and_check_torchscript(config, inputs_dict)
|
||||
|
||||
@@ -270,7 +279,7 @@ class ModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
global_rng.seed()
|
||||
|
||||
config.output_attentions = True
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = True
|
||||
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -326,7 +335,7 @@ class ModelTesterMixin:
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
|
||||
config.output_attentions = True
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = False
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -355,7 +364,7 @@ class ModelTesterMixin:
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
|
||||
config.output_attentions = True
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = False
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -388,7 +397,7 @@ class ModelTesterMixin:
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
|
||||
config.output_attentions = True
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = False
|
||||
|
||||
heads_to_prune = {
|
||||
@@ -419,7 +428,7 @@ class ModelTesterMixin:
|
||||
if "head_mask" in inputs_dict:
|
||||
del inputs_dict["head_mask"]
|
||||
|
||||
config.output_attentions = True
|
||||
inputs_dict["output_attentions"] = True
|
||||
config.output_hidden_states = False
|
||||
|
||||
heads_to_prune = {0: [0], 1: [1, 2]}
|
||||
@@ -471,14 +480,12 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = False
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict)
|
||||
hidden_states = outputs[-1]
|
||||
self.assertEqual(model.config.output_attentions, False)
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||
|
||||
@@ -838,7 +845,6 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
|
||||
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||
self.assertEqual(model.config.output_attentions, True)
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(model.config, config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user