Change model outputs types to self-document outputs (#5438)
* [WIP] Proposal for model outputs * All Bert models * Make CI green maybe? * Fix ONNX test * Isolate ModelOutput from pt and tf * Formatting * Add Electra models * Auto-generate docstrings from outputs * Add TF outputs * Add some BERT models * Revert TF side * Remove last traces of TF changes * Fail with a clear error message * Add Albert and work through Bart * Add CTRL and DistilBert * Formatting * Progress on Bart * Renames and finish Bart * Formatting * Fix last test * Add DPR * Finish Electra and add FlauBERT * Add GPT2 * Add Longformer * Add MMBT * Add MobileBert * Add GPT * Formatting * Add Reformer * Add Roberta * Add T5 * Add Transformer XL * Fix test * Add XLM + fix XLMForTokenClassification * Style + XLMRoberta * Add XLNet * Formatting * Add doc of return_tuple arg
This commit is contained in:
@@ -49,6 +49,8 @@ class PretrainedConfig(object):
|
||||
Whether or not the model should returns all attentions.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not the model should return tuples instead of :obj:`ModelOutput` objects.
|
||||
is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether the model is used as an encoder/decoder or not.
|
||||
is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
@@ -131,6 +133,7 @@ class PretrainedConfig(object):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Attributes with defaults
|
||||
self.return_tuple = kwargs.pop("return_tuple", False)
|
||||
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
||||
self.output_attentions = kwargs.pop("output_attentions", False)
|
||||
self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
|
||||
@@ -190,6 +193,11 @@ class PretrainedConfig(object):
|
||||
logger.error("Can't set {} with value {} for {}".format(key, value, self))
|
||||
raise err
|
||||
|
||||
@property
|
||||
def use_return_tuple(self):
|
||||
# If torchscript is set, force return_tuple to avoid jit errors
|
||||
return self.return_tuple or self.torchscript
|
||||
|
||||
@property
|
||||
def num_labels(self) -> int:
|
||||
return len(self.id2label)
|
||||
|
||||
Reference in New Issue
Block a user