Change model outputs types to self-document outputs (#5438)

* [WIP] Proposal for model outputs

* All Bert models

* Make CI green maybe?

* Fix ONNX test

* Isolate ModelOutput from pt and tf

* Formatting

* Add Electra models

* Auto-generate docstrings from outputs

* Add TF outputs

* Add some BERT models

* Revert TF side

* Remove last traces of TF changes

* Fail with a clear error message

* Add Albert and work through Bart

* Add CTRL and DistilBert

* Formatting

* Progress on Bart

* Renames and finish Bart

* Formatting

* Fix last test

* Add DPR

* Finish Electra and add FlauBERT

* Add GPT2

* Add Longformer

* Add MMBT

* Add MobileBert

* Add GPT

* Formatting

* Add Reformer

* Add Roberta

* Add T5

* Add Transformer XL

* Fix test

* Add XLM + fix XLMForTokenClassification

* Style + XLMRoberta

* Add XLNet

* Formatting

* Add doc of return_tuple arg
This commit is contained in:
Sylvain Gugger
2020-07-10 11:36:53 -04:00
committed by GitHub
parent fa265230a2
commit edfd82f5ff
33 changed files with 3458 additions and 2292 deletions

View File

@@ -49,6 +49,8 @@ class PretrainedConfig(object):
Whether or not the model should returns all attentions.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should return the last key/values attentions (not used by all models).
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the model should return tuples instead of :obj:`ModelOutput` objects.
is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether the model is used as an encoder/decoder or not.
is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
@@ -131,6 +133,7 @@ class PretrainedConfig(object):
def __init__(self, **kwargs):
# Attributes with defaults
self.return_tuple = kwargs.pop("return_tuple", False)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_attentions = kwargs.pop("output_attentions", False)
self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
@@ -190,6 +193,11 @@ class PretrainedConfig(object):
logger.error("Can't set {} with value {} for {}".format(key, value, self))
raise err
@property
def use_return_tuple(self):
# If torchscript is set, force return_tuple to avoid jit errors
return self.return_tuple or self.torchscript
@property
def num_labels(self) -> int:
return len(self.id2label)