Document model outputs (#5673)

* Document model outputs

* Update docs/source/main_classes/output.rst

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2020-07-10 17:31:02 -04:00
committed by GitHub
parent df983b7483
commit 7fad617dc1
18 changed files with 267 additions and 17 deletions

View File

@@ -189,7 +189,7 @@ def add_end_docstrings(*docstr):
RETURN_INTRODUCTION = r"""
Returns:
:class:`~transformers.{output_type}` or :obj:`tuple(torch.FloatTensor)` (if ``return_tuple=True`` is passed or when ``config.return_tuple=True``) comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs:
:class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)` (if ``return_tuple=True`` is passed or when ``config.return_tuple=True``) comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs:
"""
@@ -208,7 +208,8 @@ def _prepare_output_docstrings(output_type, config_class):
docstrings = "\n".join(lines[(i + 1) :])
# Add the return introduction
intro = RETURN_INTRODUCTION.format(output_type=output_type.__name__, config_class=config_class)
full_output_type = f"{output_type.__module__}.{output_type.__name__}"
intro = RETURN_INTRODUCTION.format(full_output_type=full_output_type, config_class=config_class)
return intro + docstrings
@@ -857,14 +858,24 @@ def tf_required(func):
class ModelOutput:
"""
Base class for all model outputs as dataclass. Has a ``__getitem__`` (to make it behave like a ``namedtuple``) that
will ignore ``None`` in the attributes.
Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like
a tuple) or strings (like a dictionnary) that will ignore the ``None`` attributes.
"""
def to_tuple(self):
"""
Converts :obj:`self` to a tuple.
Return: A tuple containing all non-:obj:`None` attributes of the :obj:`self`.
"""
return tuple(getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None)
def to_dict(self):
"""
Converts :obj:`self` to a Python dictionary.
Return: A dictionary containing all non-:obj:`None` attributes of the :obj:`self`.
"""
return {f: getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None}
def __getitem__(self, i):