Switch from return_tuple to return_dict (#6138)
* Switch from return_tuple to return_dict
* Fix test
* [WIP] Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests (#5614)
* Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests
* AutoModels
Tiny tweaks
* Style
* Final changes before merge
* Re-order for simpler review
* Final fixes
* Addressing @sgugger's comments
* Test MultipleChoice
* Rework TF trainer (#6038)
* Fully rework training/prediction loops
* fix method name
* Fix variable name
* Fix property name
* Fix scope
* Fix method name
* Fix tuple index
* Fix tuple index
* Fix indentation
* Fix variable name
* fix eval before log
* Add drop remainder for test dataset
* Fix step number + fix logging datetime
* fix eval loss value
* use global step instead of step + fix logging at step 0
* Fix logging datetime
* Fix global_step usage
* Fix breaking loop + logging datetime
* Fix step in prediction loop
* Fix step breaking
* Fix train/test loops
* Force TF at least 2.2 for the trainer
* Use assert_cardinality to facilitate the dataset size computation
* Log steps per epoch
* Make tfds compliant with TPU
* Make tfds compliant with TPU
* Use TF dataset enumerate instead of the Python one
* revert previous commit
* Fix data_dir
* Apply style
* rebase on master
* Address Sylvain's comments
* Address Sylvain's and Lysandre's comments
* Trigger CI
* Remove unused import
* Switch from return_tuple to return_dict
* Fix test
* Add recent model
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Plu <plu.julien@gmail.com>
This commit is contained in:
@@ -13,14 +13,17 @@ import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import fields
|
||||
from functools import partial, wraps
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
from tqdm.auto import tqdm
|
||||
@@ -190,8 +193,8 @@ def add_end_docstrings(*docstr):
|
||||
RETURN_INTRODUCTION = r"""
|
||||
Returns:
|
||||
:class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)`:
|
||||
A :class:`~{full_output_type}` or a tuple of :obj:`torch.FloatTensor` (if ``return_tuple=True`` is passed or
|
||||
when ``config.return_tuple=True``) comprising various elements depending on the configuration
|
||||
A :class:`~{full_output_type}` (if ``return_dict=True`` is passed or when ``config.return_dict=True``) or a
|
||||
tuple of :obj:`torch.FloatTensor` comprising various elements depending on the configuration
|
||||
(:class:`~transformers.{config_class}`) and inputs.
|
||||
|
||||
"""
|
||||
@@ -257,7 +260,7 @@ PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
|
||||
@@ -274,7 +277,7 @@ PT_QUESTION_ANSWERING_SAMPLE = r"""
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> start_positions = torch.tensor([1])
|
||||
@@ -293,7 +296,7 @@ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
||||
@@ -309,7 +312,7 @@ PT_MASKED_LM_SAMPLE = r"""
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
|
||||
|
||||
>>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
|
||||
|
||||
@@ -325,7 +328,7 @@ PT_BASE_MODEL_SAMPLE = r"""
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
@@ -340,7 +343,7 @@ PT_MULTIPLE_CHOICE_SAMPLE = r"""
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
|
||||
|
||||
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
>>> choice0 = "It is eaten with a fork and a knife."
|
||||
@@ -362,7 +365,7 @@ PT_CAUSAL_LM_SAMPLE = r"""
|
||||
>>> from transformers import {tokenizer_class}, {model_class}
|
||||
|
||||
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}', return_dict=True)
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> outputs = model(**inputs, labels=inputs["input_ids"])
|
||||
@@ -900,30 +903,91 @@ def tf_required(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
class ModelOutput:
|
||||
def is_tensor(x):
|
||||
""" Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`. """
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if isinstance(x, torch.Tensor):
|
||||
return True
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
if isinstance(x, tf.Tensor):
|
||||
return True
|
||||
return isinstance(x, np.ndarray)
|
||||
|
||||
|
||||
class ModelOutput(OrderedDict):
|
||||
"""
|
||||
Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like
|
||||
a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes.
|
||||
a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes. Otherwise behaves like a
|
||||
regular python dictionary.
|
||||
|
||||
.. warning::
|
||||
You can't unpack a :obj:`ModelOutput` directly. Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple`
|
||||
method to convert it to a tuple first.
|
||||
"""
|
||||
|
||||
def to_tuple(self):
|
||||
def __post_init__(self):
|
||||
class_fields = fields(self)
|
||||
|
||||
# Safety and consistency checks
|
||||
assert len(class_fields), f"{self.__class__.__name__} has no fields."
|
||||
assert all(
|
||||
field.default is None for field in class_fields[1:]
|
||||
), f"{self.__class__.__name__} should not have more than one required field."
|
||||
|
||||
first_field = getattr(self, class_fields[0].name)
|
||||
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
|
||||
|
||||
if other_fields_are_none and not is_tensor(first_field):
|
||||
try:
|
||||
iterator = iter(first_field)
|
||||
first_field_iterator = True
|
||||
except TypeError:
|
||||
first_field_iterator = False
|
||||
|
||||
# if we provided an iterator as first field and the iterator is a (key, value) iterator
|
||||
# set the associated fields
|
||||
if first_field_iterator:
|
||||
for element in iterator:
|
||||
if (
|
||||
not isinstance(element, (list, tuple))
|
||||
or not len(element) == 2
|
||||
or not isinstance(element[0], str)
|
||||
):
|
||||
break
|
||||
setattr(self, element[0], element[1])
|
||||
if element[1] is not None:
|
||||
self[element[0]] = element[1]
|
||||
else:
|
||||
for field in class_fields:
|
||||
v = getattr(self, field.name)
|
||||
if v is not None:
|
||||
self[field.name] = v
|
||||
|
||||
def __delitem__(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def setdefault(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def pop(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def __getitem__(self, k):
|
||||
if isinstance(k, str):
|
||||
inner_dict = {k: v for (k, v) in self.items()}
|
||||
return inner_dict[k]
|
||||
else:
|
||||
return self.to_tuple()[k]
|
||||
|
||||
def to_tuple(self) -> Tuple[Any]:
|
||||
"""
|
||||
Converts :obj:`self` to a tuple.
|
||||
|
||||
Return: A tuple containing all non-:obj:`None` attributes of the :obj:`self`.
|
||||
Convert self to a tuple containing all the attributes/keys that are not ``None``.
|
||||
"""
|
||||
return tuple(getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Converts :obj:`self` to a Python dictionary.
|
||||
|
||||
Return: A dictionary containing all non-:obj:`None` attributes of the :obj:`self`.
|
||||
"""
|
||||
return {f: getattr(self, f) for f in self.__dataclass_fields__.keys() if getattr(self, f, None) is not None}
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.to_dict()[i] if isinstance(i, str) else self.to_tuple()[i]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.to_tuple())
|
||||
return tuple(self[k] for k in self.keys())
|
||||
|
||||
Reference in New Issue
Block a user