Document model outputs (#5673)
* Document model outputs * Update docs/source/main_classes/output.rst Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -189,7 +189,7 @@ def add_end_docstrings(*docstr):
|
||||
|
||||
RETURN_INTRODUCTION = r"""
|
||||
Returns:
|
||||
:class:`~transformers.{output_type}` or :obj:`tuple(torch.FloatTensor)` (if ``return_tuple=True`` is passed or when ``config.return_tuple=True``) comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs:
|
||||
:class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)` (if ``return_tuple=True`` is passed or when ``config.return_tuple=True``) comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs:
|
||||
"""
|
||||
|
||||
|
||||
@@ -208,7 +208,8 @@ def _prepare_output_docstrings(output_type, config_class):
|
||||
docstrings = "\n".join(lines[(i + 1) :])
|
||||
|
||||
# Add the return introduction
|
||||
intro = RETURN_INTRODUCTION.format(output_type=output_type.__name__, config_class=config_class)
|
||||
full_output_type = f"{output_type.__module__}.{output_type.__name__}"
|
||||
intro = RETURN_INTRODUCTION.format(full_output_type=full_output_type, config_class=config_class)
|
||||
return intro + docstrings
|
||||
|
||||
|
||||
@@ -857,14 +858,24 @@ def tf_required(func):
|
||||
|
||||
class ModelOutput:
|
||||
"""
|
||||
Base class for all model outputs as dataclass. Has a ``__getitem__`` (to make it behave like a ``namedtuple``) that
|
||||
will ignore ``None`` in the attributes.
|
||||
Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like
|
||||
a tuple) or strings (like a dictionnary) that will ignore the ``None`` attributes.
|
||||
"""
|
||||
|
||||
def to_tuple(self):
|
||||
"""
|
||||
Converts :obj:`self` to a tuple.
|
||||
|
||||
Return: A tuple containing all non-:obj:`None` attributes of the :obj:`self`.
|
||||
"""
|
||||
return tuple(getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Converts :obj:`self` to a Python dictionary.
|
||||
|
||||
Return: A dictionary containing all non-:obj:`None` attributes of the :obj:`self`.
|
||||
"""
|
||||
return {f: getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None}
|
||||
|
||||
def __getitem__(self, i):
|
||||
|
||||
Reference in New Issue
Block a user