[All models] Extend config.output_attentions with output_attentions function arguments (#4538)
* DOC: Replace instances of ``config.output_attentions`` with function argument ``output_attentions`` * DOC: Apply Black Formatting * Fix errors where output_attentions was undefined * Remove output_attentions in classes per review * Fix regressions on tests having `output_attention` * Fix further regressions in tests relating to `output_attentions` Ensure proper propagation of `output_attentions` as a function parameter to all model subclasses * Fix more regressions in `test_output_attentions` * Fix issues with BertEncoder * Rename related variables to `output_attentions` * fix pytorch tests * fix bert and gpt2 tf * Fix most TF tests for `test_output_attentions` * Fix linter errors and more TF tests * fix conflicts * DOC: Apply Black Formatting * Fix errors where output_attentions was undefined * Remove output_attentions in classes per review * Fix regressions on tests having `output_attention` * fix conflicts * fix conflicts * fix conflicts * fix conflicts * fix pytorch tests * fix conflicts * fix conflicts * Fix linter errors and more TF tests * fix tf tests * make style * fix isort * improve output_attentions * improve tensorflow Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f90bc44d9a
commit
6e603cb789
@@ -44,8 +44,6 @@ class PretrainedConfig(object):
|
||||
Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
|
||||
num_labels (:obj:`int`, `optional`, defaults to `2`):
|
||||
Number of classes to use when the model is a classification model (sequences/tokens)
|
||||
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Should the model returns attentions weights.
|
||||
output_hidden_states (:obj:`string`, `optional`, defaults to :obj:`False`):
|
||||
Should the model returns all hidden-states.
|
||||
torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
@@ -55,8 +53,8 @@ class PretrainedConfig(object):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Attributes with defaults
|
||||
self.output_attentions = kwargs.pop("output_attentions", False)
|
||||
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
||||
self.output_attentions = kwargs.pop("output_attentions", False)
|
||||
self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
|
||||
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
||||
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
||||
|
||||
Reference in New Issue
Block a user