Change model outputs types to self-document outputs (#5438)
* [WIP] Proposal for model outputs
* All Bert models
* Make CI green maybe?
* Fix ONNX test
* Isolate ModelOutput from pt and tf
* Formatting
* Add Electra models
* Auto-generate docstrings from outputs
* Add TF outputs
* Add some BERT models
* Revert TF side
* Remove last traces of TF changes
* Fail with a clear error message
* Add Albert and work through Bart
* Add CTRL and DistilBert
* Formatting
* Progress on Bart
* Renames and finish Bart
* Formatting
* Fix last test
* Add DPR
* Finish Electra and add FlauBERT
* Add GPT2
* Add Longformer
* Add MMBT
* Add MobileBert
* Add GPT
* Formatting
* Add Reformer
* Add Roberta
* Add T5
* Add Transformer XL
* Fix test
* Add XLM + fix XLMForTokenClassification
* Style + XLMRoberta
* Add XLNet
* Formatting
* Add doc of return_tuple arg
This commit is contained in:
@@ -8,6 +8,7 @@ import fnmatch
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
@@ -186,6 +187,31 @@ def add_end_docstrings(*docstr):
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
# Template for the head of a generated ``Returns:`` section. The ``{output_type}`` and ``{config_class}``
# placeholders are filled in by ``_prepare_output_docstrings``.
RETURN_INTRODUCTION = r"""
    Returns:
        :class:`~transformers.{output_type}` or :obj:`tuple(torch.FloatTensor)` (if ``return_tuple=True`` is passed or when ``config.return_tuple=True``) comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs:
"""


def _prepare_output_docstrings(output_type, config_class):
    """
    Prepares the return part of the docstring using `output_type`.

    Args:
        output_type: Class whose ``__doc__`` lists the returned fields and whose ``__name__`` is inserted into
            the introduction.
        config_class: Name of the configuration class inserted into the introduction.

    Returns:
        :obj:`str`: ``RETURN_INTRODUCTION`` formatted with both names, followed by everything that comes after
        the first ``Args:`` / ``Parameters:`` line of the docstring of `output_type`. When no such line exists
        the whole docstring is appended; when `output_type` has no docstring nothing is appended.
    """
    # ``__doc__`` is ``None`` for an undocumented class (and under ``python -OO``); treat that as empty
    # instead of failing on ``None.split``.
    docstrings = output_type.__doc__ or ""

    # Remove the head of the docstring to keep the list of args only
    lines = docstrings.split("\n")
    args_header = re.compile(r"^\s*(Args|Parameters):\s*$")
    header_index = next((idx for idx, line in enumerate(lines) if args_header.search(line) is not None), None)
    if header_index is not None:
        docstrings = "\n".join(lines[header_index + 1 :])

    # Add the return introduction
    intro = RETURN_INTRODUCTION.format(output_type=output_type.__name__, config_class=config_class)
    return intro + docstrings
|
||||
|
||||
|
||||
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
@@ -414,7 +440,7 @@ TF_CAUSAL_LM_SAMPLE = r"""
|
||||
"""
|
||||
|
||||
|
||||
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
|
||||
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None):
|
||||
def docstring_decorator(fn):
|
||||
model_class = fn.__qualname__.split(".")[0]
|
||||
is_tf_class = model_class[:2] == "TF"
|
||||
@@ -436,8 +462,29 @@ def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
|
||||
else:
|
||||
raise ValueError(f"Docstring can't be built for model {model_class}")
|
||||
|
||||
output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""
|
||||
built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
|
||||
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
|
||||
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc
|
||||
return fn
|
||||
|
||||
return docstring_decorator
|
||||
|
||||
|
||||
def replace_return_docstrings(output_type=None, config_class=None):
    """
    Decorator factory that fills in the return section of a function docstring.

    The decorated function must contain a line consisting only of ``Return:`` or ``Returns:`` (surrounding
    whitespace allowed). The first such line is replaced by the text produced by
    ``_prepare_output_docstrings(output_type, config_class)``.

    Args:
        output_type: Forwarded to ``_prepare_output_docstrings``.
        config_class: Forwarded to ``_prepare_output_docstrings``.

    Returns:
        A decorator that rewrites ``fn.__doc__`` in place and returns ``fn`` itself.

    Raises:
        ValueError: When decorating a function whose docstring has no placeholder line, including a function
            without any docstring.
    """

    def docstring_decorator(fn):
        # ``fn.__doc__`` is ``None`` when the function has no docstring; normalise it so that case reaches the
        # descriptive ValueError below instead of failing on ``None.split``.
        docstrings = fn.__doc__ or ""
        lines = docstrings.split("\n")
        placeholder_index = next(
            (idx for idx, line in enumerate(lines) if re.search(r"^\s*Returns?:\s*$", line) is not None), None
        )
        if placeholder_index is None:
            raise ValueError(
                f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, current docstring is:\n{docstrings}"
            )
        lines[placeholder_index] = _prepare_output_docstrings(output_type, config_class)
        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator
|
||||
@@ -806,3 +853,22 @@ def tf_required(func):
|
||||
raise ImportError(f"Method `{func.__name__}` requires TF.")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ModelOutput:
    """
    Base class for all model outputs as dataclass. Has a ``__getitem__`` (to make it behave like a ``namedtuple``) that
    will ignore ``None`` in the attributes.
    """

    def to_tuple(self):
        """Return the values of all fields that are not ``None``, in ``__dataclass_fields__`` order."""
        kept = []
        for name in self.__dataclass_fields__:
            value = getattr(self, name, None)
            if value is not None:
                kept.append(value)
        return tuple(kept)

    def to_dict(self):
        """Return a ``{field name: value}`` mapping restricted to the fields that are not ``None``."""
        kept = {}
        for name in self.__dataclass_fields__:
            value = getattr(self, name, None)
            if value is not None:
                kept[name] = value
        return kept

    def __getitem__(self, i):
        """Look up by field name when `i` is a string, otherwise index into ``to_tuple()``."""
        if isinstance(i, str):
            return self.to_dict()[i]
        return self.to_tuple()[i]

    def __len__(self):
        """Return the number of fields that are not ``None``."""
        return len(self.to_tuple())
|
||||
|
||||
Reference in New Issue
Block a user