Change model outputs types to self-document outputs (#5438)

* [WIP] Proposal for model outputs

* All Bert models

* Make CI green maybe?

* Fix ONNX test

* Isolate ModelOutput from pt and tf

* Formatting

* Add Electra models

* Auto-generate docstrings from outputs

* Add TF outputs

* Add some BERT models

* Revert TF side

* Remove last traces of TF changes

* Fail with a clear error message

* Add Albert and work through Bart

* Add CTRL and DistilBert

* Formatting

* Progress on Bart

* Renames and finish Bart

* Formatting

* Fix last test

* Add DPR

* Finish Electra and add FlauBERT

* Add GPT2

* Add Longformer

* Add MMBT

* Add MobileBert

* Add GPT

* Formatting

* Add Reformer

* Add Roberta

* Add T5

* Add Transformer XL

* Fix test

* Add XLM + fix XLMForTokenClassification

* Style + XLMRoberta

* Add XLNet

* Formatting

* Add doc of return_tuple arg
This commit is contained in:
Sylvain Gugger
2020-07-10 11:36:53 -04:00
committed by GitHub
parent fa265230a2
commit edfd82f5ff
33 changed files with 3458 additions and 2292 deletions

View File

@@ -8,6 +8,7 @@ import fnmatch
import json
import logging
import os
import re
import shutil
import sys
import tarfile
@@ -186,6 +187,31 @@ def add_end_docstrings(*docstr):
return docstring_decorator
RETURN_INTRODUCTION = r"""
Returns:
:class:`~transformers.{output_type}` or :obj:`tuple(torch.FloatTensor)` (if ``return_tuple=True`` is passed or when ``config.return_tuple=True``) comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs:
"""
def _prepare_output_docstrings(output_type, config_class):
"""
Prepares the return part of the docstring using `output_type`.
"""
docstrings = output_type.__doc__
# Remove the head of the docstring to keep the list of args only
lines = docstrings.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
i += 1
if i < len(lines):
docstrings = "\n".join(lines[(i + 1) :])
# Add the return introduction
intro = RETURN_INTRODUCTION.format(output_type=output_type.__name__, config_class=config_class)
return intro + docstrings
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example::
@@ -414,7 +440,7 @@ TF_CAUSAL_LM_SAMPLE = r"""
"""
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None):
def docstring_decorator(fn):
model_class = fn.__qualname__.split(".")[0]
is_tf_class = model_class[:2] == "TF"
@@ -436,8 +462,29 @@ def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""
built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc
return fn
return docstring_decorator
def replace_return_docstrings(output_type=None, config_class=None):
def docstring_decorator(fn):
docstrings = fn.__doc__
lines = docstrings.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
i += 1
if i < len(lines):
lines[i] = _prepare_output_docstrings(output_type, config_class)
docstrings = "\n".join(lines)
else:
raise ValueError(
f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, current docstring is:\n{docstrings}"
)
fn.__doc__ = docstrings
return fn
return docstring_decorator
@@ -806,3 +853,22 @@ def tf_required(func):
raise ImportError(f"Method `{func.__name__}` requires TF.")
return wrapper
class ModelOutput:
"""
Base class for all model outputs as dataclass. Has a ``__getitem__`` (to make it behave like a ``namedtuple``) that
will ignore ``None`` in the attributes.
"""
def to_tuple(self):
return tuple(getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None)
def to_dict(self):
return {f: getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None}
def __getitem__(self, i):
return self.to_dict()[i] if isinstance(i, str) else self.to_tuple()[i]
def __len__(self):
return len(self.to_tuple())