From 1246b20f6d81bcd949078d26cf5ab3d0f3acccc6 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 27 Jul 2020 09:18:59 -0400 Subject: [PATCH] Fix the return documentation rendering for all model outputs (#6022) * Fix the return documentation rendering for all model outputs * Formatting --- src/transformers/file_utils.py | 39 ++++++++++++++++++++++++- src/transformers/modeling_transfo_xl.py | 2 -- 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 57b8c3d310..5a46fb0624 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -189,10 +189,46 @@ def add_end_docstrings(*docstr): RETURN_INTRODUCTION = r""" Returns: - :class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)` (if ``return_tuple=True`` is passed or when ``config.return_tuple=True``) comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs: + :class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)`: + A :class:`~{full_output_type}` or a tuple of :obj:`torch.FloatTensor` (if ``return_tuple=True`` is passed or + when ``config.return_tuple=True``) comprising various elements depending on the configuration + (:class:`~transformers.{config_class}`) and inputs. + """ +def _get_indent(t): + """Returns the indentation in the first line of t""" + search = re.search(r"^(\s*)\S", t) + return "" if search is None else search.groups()[0] + + +def _convert_output_args_doc(output_args_doc): + """Convert output_args_doc to display properly.""" + # Split output_arg_doc in blocks argument/description + indent = _get_indent(output_args_doc) + blocks = [] + current_block = "" + for line in output_args_doc.split("\n"): + # If the indent is the same as the beginning, the line is the name of new arg. + if _get_indent(line) == indent: + if len(current_block) > 0: + blocks.append(current_block[:-1]) + current_block = f"{line}\n" + else: + # Otherwise it's part of the description of the current arg. + # We need to remove 2 spaces to the indentation. + current_block += f"{line[2:]}\n" + blocks.append(current_block[:-1]) + + # Format each block for proper rendering + for i in range(len(blocks)): + blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i]) + blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i]) + + return "\n".join(blocks) + + def _prepare_output_docstrings(output_type, config_class): """ Prepares the return part of the docstring using `output_type`. @@ -206,6 +242,7 @@ def _prepare_output_docstrings(output_type, config_class): i += 1 if i < len(lines): docstrings = "\n".join(lines[(i + 1) :]) + docstrings = _convert_output_args_doc(docstrings) # Add the return introduction full_output_type = f"{output_type.__module__}.{output_type.__name__}" diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py index ba8285f388..ca98fe5abc 100644 --- a/src/transformers/modeling_transfo_xl.py +++ b/src/transformers/modeling_transfo_xl.py @@ -629,8 +629,6 @@ class TransfoXLLMHeadModelOutput(ModelOutput): Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). Args: - - Language modeling loss (for next-token prediction). losses (:obj:`torch.FloatTensor` of shape `(batch_size, sequence_length-1)`, `optional`, returned when ``labels`` is provided) Language modeling losses (not reduced). prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):