* memory tracker metrics * go back to eval for somewhat consistency * handle no-gpu case * deal with stackable eval calls * restore callback order * style * simplify the API * add test * docs * consistently use eval_ prefix * improve docs * Update src/transformers/trainer_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * rename method * style Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1527 lines
56 KiB
Python
1527 lines
56 KiB
Python
# Copyright 2020 The HuggingFace Team, the AllenNLP library authors. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Utilities for working with the local dataset cache. Parts of this file are adapted from the AllenNLP library at
|
|
https://github.com/allenai/allennlp.
|
|
"""
|
|
|
|
import copy
|
|
import fnmatch
|
|
import importlib.util
|
|
import io
|
|
import json
|
|
import os
|
|
import re
|
|
import shutil
|
|
import sys
|
|
import tarfile
|
|
import tempfile
|
|
from collections import OrderedDict
|
|
from contextlib import contextmanager
|
|
from dataclasses import fields
|
|
from functools import partial, wraps
|
|
from hashlib import sha256
|
|
from pathlib import Path
|
|
from types import ModuleType
|
|
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
|
|
from urllib.parse import urlparse
|
|
from zipfile import ZipFile, is_zipfile
|
|
|
|
import numpy as np
|
|
from packaging import version
|
|
from tqdm.auto import tqdm
|
|
|
|
import requests
|
|
from filelock import FileLock
|
|
|
|
from . import __version__
|
|
from .hf_api import HfFolder
|
|
from .utils import logging
|
|
|
|
|
|
# The package importlib_metadata is in a different place, depending on the python version.
|
|
if sys.version_info < (3, 8):
|
|
import importlib_metadata
|
|
else:
|
|
import importlib.metadata as importlib_metadata
|
|
|
|
|
|
# Module-level logger obtained through the library's own `logging` utilities (imported from `.utils`).
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
|
|
|
# Upper-cased spellings that turn a backend on through its USE_* environment variable.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
# "AUTO" additionally means "use the backend if it is installed" (the default when the variable is unset).
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES | {"AUTO"}

# Note that the JAX/Flax switch is read from the USE_FLAX environment variable.
USE_TF, USE_TORCH, USE_JAX = (
    os.environ.get(env_var, "AUTO").upper() for env_var in ("USE_TF", "USE_TORCH", "USE_FLAX")
)
|
|
|
|
# PyTorch is considered when USE_TORCH is enabled/AUTO, unless USE_TF explicitly requests TensorFlow.
if USE_TORCH not in ENV_VARS_TRUE_AND_AUTO_VALUES or USE_TF in ENV_VARS_TRUE_VALUES:
    logger.info("Disabling PyTorch because USE_TF is set")
    _torch_available = False
else:
    _torch_available = importlib.util.find_spec("torch") is not None
    if _torch_available:
        try:
            _torch_version = importlib_metadata.version("torch")
        except importlib_metadata.PackageNotFoundError:
            # An importable `torch` module without installed distribution metadata is not trusted.
            _torch_available = False
        else:
            logger.info(f"PyTorch version {_torch_version} available.")
|
|
|
|
|
|
# TensorFlow is considered when USE_TF is enabled/AUTO, unless USE_TORCH explicitly requests PyTorch.
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
    _tf_available = importlib.util.find_spec("tensorflow") is not None
    if _tf_available:
        # The `tensorflow` module can be provided by several pip distributions, so query their metadata in turn
        # (first match wins). An importable module with none of these distributions installed is not trusted.
        _tf_version = None
        for _tf_distribution in (
            "tensorflow",
            "tensorflow-cpu",
            "tensorflow-gpu",
            "tf-nightly",
            "tf-nightly-cpu",
            "tf-nightly-gpu",
        ):
            try:
                _tf_version = importlib_metadata.version(_tf_distribution)
                break
            except importlib_metadata.PackageNotFoundError:
                continue
        del _tf_distribution  # loop helper only; keep the module namespace as before
        _tf_available = _tf_version is not None
    if _tf_available:
        if version.parse(_tf_version) < version.parse("2"):
            logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
            _tf_available = False
        else:
            logger.info(f"TensorFlow version {_tf_version} available.")
else:
    logger.info("Disabling Tensorflow because USE_TORCH is set")
    _tf_available = False
|
|
|
|
|
|
if USE_JAX not in ENV_VARS_TRUE_AND_AUTO_VALUES:
    _flax_available = False
else:
    # Both packages are required; `all` stops at the first missing one, exactly like a chained `and`.
    _flax_available = all(importlib.util.find_spec(package) is not None for package in ("jax", "flax"))
    if _flax_available:
        try:
            _jax_version = importlib_metadata.version("jax")
            _flax_version = importlib_metadata.version("flax")
        except importlib_metadata.PackageNotFoundError:
            # Importable modules without distribution metadata are not trusted as real installs.
            _flax_available = False
        else:
            logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
|
|
|
|
|
|
_datasets_available = importlib.util.find_spec("datasets") is not None
try:
    # Check we're not importing a "datasets" directory somewhere but the actual library: a real install has
    # distribution metadata (so the version lookup succeeds) and its author field is HuggingFace.
    _ = importlib_metadata.version("datasets")
    _datasets_metadata = importlib_metadata.metadata("datasets")
except importlib_metadata.PackageNotFoundError:
    _datasets_available = False
else:
    _datasets_available = _datasets_available and _datasets_metadata.get("author", "") == "HuggingFace Inc."
|
|
|
|
|
|
_faiss_available = importlib.util.find_spec("faiss") is not None
# The `faiss` module ships in more than one pip distribution; the first one exposing metadata wins.
for _faiss_distribution in ("faiss", "faiss-cpu"):
    try:
        _faiss_version = importlib_metadata.version(_faiss_distribution)
    except importlib_metadata.PackageNotFoundError:
        continue
    logger.debug(f"Successfully imported faiss version {_faiss_version}")
    break
else:
    # No known distribution has metadata: do not trust a bare importable `faiss` module.
    _faiss_available = False
del _faiss_distribution  # loop helper only
|
|
|
|
|
|
# Availability requires `keras2onnx` and `onnxruntime` to be importable plus an installed `onnx` distribution.
_onnx_available = all(
    importlib.util.find_spec(module_name) is not None for module_name in ("keras2onnx", "onnxruntime")
)
try:
    _onxx_version = importlib_metadata.version("onnx")  # (sic) historical variable name, kept unchanged
except importlib_metadata.PackageNotFoundError:
    _onnx_available = False
else:
    logger.debug(f"Successfully imported onnx version {_onxx_version}")
|
|
|
|
|
|
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
try:
    _scatter_version = importlib_metadata.version("torch_scatter")
except importlib_metadata.PackageNotFoundError:
    # Module found on the path but no distribution metadata: treat as unavailable.
    _scatter_available = False
else:
    logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
|
|
|
|
|
|
_soundfile_available = importlib.util.find_spec("soundfile") is not None
try:
    _soundfile_version = importlib_metadata.version("soundfile")
except importlib_metadata.PackageNotFoundError:
    # Module found on the path but no distribution metadata: treat as unavailable.
    _soundfile_available = False
else:
    logger.debug(f"Successfully imported soundfile version {_soundfile_version}")
|
|
|
|
|
|
# Legacy (pre-v4.0.0) cache location under the PyTorch cache directory. "~" must be expanded here too: otherwise
# `os.path.isdir` below checks a path literally starting with "~" and the one-time migration can never trigger.
torch_cache_home = os.path.expanduser(
    os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
)
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "transformers")

# Onetime move from the old location to the new one if no ENV variable has been set.
if (
    os.path.isdir(old_default_cache_path)
    and not os.path.isdir(default_cache_path)
    and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
    and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
    and "TRANSFORMERS_CACHE" not in os.environ
):
    # `warning` rather than `warn`, which is a deprecated alias in the standard-library Logger API.
    logger.warning(
        "In Transformers v4.0.0, the default path to cache downloaded models changed from "
        "'~/.cache/torch/transformers' to '~/.cache/huggingface/transformers'. Since you don't seem to have overridden "
        "and '~/.cache/torch/transformers' is a directory that exists, we're moving it to "
        "'~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should "
        "only see this message once."
    )
    try:
        shutil.move(old_default_cache_path, default_cache_path)
    except OSError as err:
        # This runs at import time: a failed move (permissions, read-only filesystem, ...) must not make the
        # whole library unimportable, so report it and keep going.
        logger.warning(f"Could not move the cache from {old_default_cache_path} to {default_cache_path}: {err}")

# Resolved cache directory, by precedence: TRANSFORMERS_CACHE, then PYTORCH_TRANSFORMERS_CACHE, then
# PYTORCH_PRETRAINED_BERT_CACHE, then the default path computed above.
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
|
|
|
# Canonical file names for the serialized weights of each framework (PyTorch, TF2, TF1-style checkpoint, Flax), the
# configuration, and the model card (presumably looked up/saved elsewhere in the library — not shown here).
WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
CONFIG_NAME = "config.json"
MODEL_CARD_NAME = "modelcard.json"
|
|
|
|
# The "▁" character (U+2581, LOWER ONE EIGHTH BLOCK) used as SentencePiece's meta-symbol.
SENTENCEPIECE_UNDERLINE = "▁"
SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE  # Kept for backward compatibility

# Small dummy batches of token ids (presumably for building/tracing models without real data — TODO confirm).
# NOTE: `[...] * 2` repeats the *same* inner list object, so both outer entries alias each other; never mutate it.
MULTIPLE_CHOICE_DUMMY_INPUTS = [
    [[0, 1, 0, 1], [1, 0, 0, 1]]
] * 2  # Needs to have 0s and 1s only since XLM uses it for langs too.
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
# Same 3x5 layout as DUMMY_INPUTS (presumably its attention mask — TODO confirm).
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
|
|
|
|
# URL prefixes of the S3 bucket and CDN hosting model files.
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
# URL template for a file on huggingface.co, with `{model_id}`, `{revision}` and `{filename}` placeholders.
HUGGINGFACE_CO_PREFIX = "https://huggingface.co/{model_id}/resolve/{revision}/{filename}"

# Short names for known download mirrors (presumably selectable by name elsewhere in the library — TODO confirm).
PRESET_MIRROR_DICT = {
    "tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
    "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
|
|
|
|
|
|
def is_torch_available():
    """
    Return whether PyTorch can be used. This is decided once, when this module is imported: the ``torch`` module
    is importable, its distribution metadata is installed, and it was not disabled through ``USE_TORCH``/``USE_TF``.
    """
    return _torch_available


def is_torch_cuda_available():
    """Return whether PyTorch is available and ``torch.cuda.is_available()`` reports a usable CUDA setup."""
    if not is_torch_available():
        return False
    # Imported lazily so that importing this module never requires PyTorch.
    import torch

    return torch.cuda.is_available()
|
|
|
|
|
|
def is_tf_available():
    """
    Return whether TensorFlow (version 2 or newer) was detected when this module was imported, i.e. it is installed
    and was not disabled through ``USE_TF``/``USE_TORCH``.
    """
    return _tf_available


def is_onnx_available():
    """
    Return whether ``keras2onnx`` and ``onnxruntime`` are importable and an ``onnx`` distribution is installed
    (checked once at import time).
    """
    return _onnx_available


def is_flax_available():
    """
    Return whether both JAX and Flax were detected at import time (with distribution metadata) and were not
    disabled through ``USE_FLAX``.
    """
    return _flax_available
|
|
|
|
|
|
def is_torch_tpu_available():
    """
    Return whether PyTorch is available and the ``torch_xla.core.xla_model`` module can be located.

    Only module presence is checked; this does not verify that TPU devices are attached. Locating the dotted
    submodules imports their parent packages.
    """
    if not _torch_available:
        return False
    # Walk down the package path one level at a time: `find_spec` on a dotted name imports the parent package and
    # raises ModuleNotFoundError when that parent does not exist, so each parent is probed first.
    for parent_module in ("torch_xla", "torch_xla.core"):
        if importlib.util.find_spec(parent_module) is None:
            return False
    return importlib.util.find_spec("torch_xla.core.xla_model") is not None
|
|
|
|
|
|
def is_datasets_available():
    """
    Return whether the 🤗 Datasets library was detected at import time: the ``datasets`` module is importable and its
    distribution metadata lists "HuggingFace Inc." as author (so a stray local ``datasets`` folder is not accepted).
    """
    return _datasets_available
|
|
|
|
|
|
def is_psutil_available():
    """Return True if a top-level ``psutil`` module can be located on the import path (it is not imported)."""
    psutil_spec = importlib.util.find_spec("psutil")
    return psutil_spec is not None


def is_py3nvml_available():
    """Return True if a top-level ``py3nvml`` module can be located on the import path (it is not imported)."""
    py3nvml_spec = importlib.util.find_spec("py3nvml")
    return py3nvml_spec is not None


def is_apex_available():
    """Return True if a top-level ``apex`` module can be located on the import path (it is not imported)."""
    apex_spec = importlib.util.find_spec("apex")
    return apex_spec is not None
|
|
|
|
|
|
def is_faiss_available():
    """
    Return whether ``faiss`` was detected at import time: the module is importable and a ``faiss`` or ``faiss-cpu``
    distribution is installed.
    """
    return _faiss_available
|
|
|
|
|
|
def is_sklearn_available():
    """
    Return whether scikit-learn and SciPy can be used, i.e. ``sklearn``, ``scipy``, ``sklearn.metrics`` and
    ``scipy.stats`` can all be located.

    Returns:
        :obj:`bool`: ``True`` when all four modules are found, ``False`` otherwise.
    """
    # Probe the top-level packages first: `find_spec` on a dotted name imports the parent package and raises
    # ModuleNotFoundError when that parent is missing.
    if importlib.util.find_spec("sklearn") is None:
        return False
    if importlib.util.find_spec("scipy") is None:
        return False
    # Return a real bool like every other `is_*_available` helper (previously a ModuleSpec or None leaked out).
    return (
        importlib.util.find_spec("sklearn.metrics") is not None
        and importlib.util.find_spec("scipy.stats") is not None
    )
|
|
|
|
|
|
def is_sentencepiece_available():
    """Return True if a top-level ``sentencepiece`` module can be located on the import path."""
    sentencepiece_spec = importlib.util.find_spec("sentencepiece")
    return sentencepiece_spec is not None


def is_protobuf_available():
    """Return True if ``google.protobuf`` can be located on the import path."""
    # The parent must be checked first: `find_spec("google.protobuf")` imports `google` and raises when it is absent.
    return importlib.util.find_spec("google") is not None and importlib.util.find_spec("google.protobuf") is not None


def is_tokenizers_available():
    """Return True if a top-level ``tokenizers`` module can be located on the import path."""
    tokenizers_spec = importlib.util.find_spec("tokenizers")
    return tokenizers_spec is not None
|
|
|
|
|
|
def is_in_notebook():
    """
    Return whether the code runs inside an IPython kernel that was not launched by VS Code.

    The check requires ``IPython`` to be already imported, its ``get_ipython()`` config to contain ``IPKernelApp``,
    and ``VSCODE_PID`` to be absent from the environment. Any ``AttributeError``, ``ImportError`` or ``KeyError``
    raised along the way means "not in a notebook".
    """
    # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
    try:
        shell = sys.modules["IPython"].get_ipython()
        running_in_kernel = "IPKernelApp" in shell.config
    except (AttributeError, ImportError, KeyError):
        return False
    if not running_in_kernel or "VSCODE_PID" in os.environ:
        return False
    try:
        return importlib.util.find_spec("IPython") is not None
    except (AttributeError, ImportError, KeyError):
        return False
|
|
|
|
|
|
def is_scatter_available():
    """
    Return whether ``torch_scatter`` was detected at import time: the module is importable and its distribution
    metadata is installed.
    """
    return _scatter_available
|
|
|
|
|
|
def is_pandas_available():
    """Return True if a top-level ``pandas`` module can be located on the import path (it is not imported)."""
    pandas_spec = importlib.util.find_spec("pandas")
    return pandas_spec is not None
|
|
|
|
|
|
def is_sagemaker_distributed_available():
    """
    Return whether SageMaker's distributed data parallel library can be used.

    This requires the ``SM_FRAMEWORK_PARAMS`` environment variable to hold a JSON object whose
    ``"sagemaker_distributed_dataparallel_enabled"`` field is truthy, and the ``smdistributed`` module to be
    locatable. Missing, malformed or non-object values all yield ``False``.
    """
    # Get the sagemaker specific env variable.
    sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}")
    try:
        sagemaker_params = json.loads(sagemaker_params)
    except json.JSONDecodeError:
        return False
    # Valid JSON that is not an object (e.g. "[]", "null", "1") has no fields; previously `.get` crashed on it.
    if not isinstance(sagemaker_params, dict):
        return False
    # Check the field "sagemaker_distributed_dataparallel_enabled".
    if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False):
        return False
    # Lastly, check if the `smdistributed` module is present.
    return importlib.util.find_spec("smdistributed") is not None
|
|
|
|
|
|
def is_soundfile_available():
    """
    Return whether ``soundfile`` was detected at import time: the module is importable and its distribution
    metadata is installed.
    """
    return _soundfile_available


# Historical misspelled name, kept as an alias so existing imports keep working.
is_soundfile_availble = is_soundfile_available
|
|
|
|
|
|
def torch_only_method(fn):
    """
    Decorator making ``fn`` raise an ``ImportError`` when PyTorch is not available.

    The availability flag is checked on every call, not at decoration time. ``functools.wraps`` keeps the wrapped
    function's name, docstring and other metadata, so decorated methods still document and introspect correctly.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not _torch_available:
            raise ImportError(
                "You need to install pytorch to use this method or class, "
                "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
            )
        return fn(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
# Message templates for the ImportError raised by the `requires_*` helpers; `{0}` is replaced with the name of the
# function/class that needs the missing backend. `# docstyle-ignore` keeps the doc-style tooling from reflowing them.
# docstyle-ignore
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
```
pip install datasets
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install datasets
```
then restarting your kernel.

Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current
working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or
that python file if that's the case.
"""


# docstyle-ignore
TOKENIZERS_IMPORT_ERROR = """
{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with:
```
pip install tokenizers
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install tokenizers
```
"""


# docstyle-ignore
SENTENCEPIECE_IMPORT_ERROR = """
{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment.
"""


# docstyle-ignore
PROTOBUF_IMPORT_ERROR = """
{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
that match your environment.
"""


# docstyle-ignore
FAISS_IMPORT_ERROR = """
{0} requires the faiss library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones
that match your environment.
"""


# docstyle-ignore
PYTORCH_IMPORT_ERROR = """
{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
"""


# docstyle-ignore
SKLEARN_IMPORT_ERROR = """
{0} requires the scikit-learn library but it was not found in your environment. You can install it with:
```
pip install -U scikit-learn
```
In a notebook or a colab, you can install it by executing a cell with
```
!pip install -U scikit-learn
```
"""


# docstyle-ignore
TENSORFLOW_IMPORT_ERROR = """
{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
"""


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""


# docstyle-ignore
SCATTER_IMPORT_ERROR = """
{0} requires the torch-scatter library but it was not found in your environment. You can install it with pip as
explained here: https://github.com/rusty1s/pytorch_scatter.
"""


# docstyle-ignore
PANDAS_IMPORT_ERROR = """
{0} requires the pandas library but it was not found in your environment. You can install it with pip as
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
"""
|
|
|
|
|
|
def _requires_backend(obj, is_available, import_error):
    """
    Raise an ``ImportError`` for ``obj`` when a backend is missing.

    Args:
        obj: the function, class or instance that needs the backend. Its ``__name__`` (or, failing that, its class
            name) is substituted for ``{0}`` in ``import_error``.
        is_available: zero-argument callable reporting whether the backend can be used.
        import_error: message template with a ``{0}`` placeholder.
    """
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_available():
        raise ImportError(import_error.format(name))


def requires_datasets(obj):
    """Raise an ``ImportError`` naming ``obj`` unless the Datasets library is available."""
    _requires_backend(obj, is_datasets_available, DATASETS_IMPORT_ERROR)


def requires_faiss(obj):
    """Raise an ``ImportError`` naming ``obj`` unless faiss is available."""
    _requires_backend(obj, is_faiss_available, FAISS_IMPORT_ERROR)


def requires_pytorch(obj):
    """Raise an ``ImportError`` naming ``obj`` unless PyTorch is available."""
    _requires_backend(obj, is_torch_available, PYTORCH_IMPORT_ERROR)


def requires_sklearn(obj):
    """Raise an ``ImportError`` naming ``obj`` unless scikit-learn and SciPy are available."""
    _requires_backend(obj, is_sklearn_available, SKLEARN_IMPORT_ERROR)


def requires_tf(obj):
    """Raise an ``ImportError`` naming ``obj`` unless TensorFlow is available."""
    _requires_backend(obj, is_tf_available, TENSORFLOW_IMPORT_ERROR)


def requires_flax(obj):
    """Raise an ``ImportError`` naming ``obj`` unless JAX/Flax are available."""
    _requires_backend(obj, is_flax_available, FLAX_IMPORT_ERROR)


def requires_tokenizers(obj):
    """Raise an ``ImportError`` naming ``obj`` unless the Tokenizers library is available."""
    _requires_backend(obj, is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)


def requires_sentencepiece(obj):
    """Raise an ``ImportError`` naming ``obj`` unless SentencePiece is available."""
    _requires_backend(obj, is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)


def requires_protobuf(obj):
    """Raise an ``ImportError`` naming ``obj`` unless protobuf is available."""
    _requires_backend(obj, is_protobuf_available, PROTOBUF_IMPORT_ERROR)


def requires_pandas(obj):
    """Raise an ``ImportError`` naming ``obj`` unless pandas is available."""
    _requires_backend(obj, is_pandas_available, PANDAS_IMPORT_ERROR)


def requires_scatter(obj):
    """Raise an ``ImportError`` naming ``obj`` unless torch-scatter is available."""
    _requires_backend(obj, is_scatter_available, SCATTER_IMPORT_ERROR)
|
|
|
|
|
|
def add_start_docstrings(*docstr):
    """
    Decorator factory that prepends the concatenation of ``docstr`` to the decorated object's docstring.

    A missing docstring (``None``) is treated as empty. The decorated object itself is returned.
    """

    def prepend_docstring(fn):
        existing_doc = fn.__doc__
        if existing_doc is None:
            existing_doc = ""
        fn.__doc__ = "".join(docstr) + existing_doc
        return fn

    return prepend_docstring
|
|
|
|
|
|
def add_start_docstrings_to_model_forward(*docstr):
    """
    Decorator factory for a model's ``forward`` method.

    The new docstring is built from four parts, in order: an intro naming the owning class (first component of
    ``fn.__qualname__``), a note recommending calling the module instance rather than ``forward``, the concatenated
    ``docstr``, and the original docstring.
    """

    def docstring_decorator(fn):
        owner_name = fn.__qualname__.split(".")[0]
        class_name = f":class:`~transformers.{owner_name}`"
        intro = f"   The {class_name} forward method, overrides the :func:`__call__` special method."
        note = r"""

    .. note::
        Although the recipe for forward pass needs to be defined within this function, one should call the
        :class:`Module` instance afterwards instead of this since the former takes care of running the pre and post
        processing steps while the latter silently ignores them.
        """
        fn.__doc__ = "".join([intro, note, *docstr, fn.__doc__ or ""])
        return fn

    return docstring_decorator
|
|
|
|
|
|
def add_end_docstrings(*docstr):
    """
    Decorator factory that appends the concatenation of ``docstr`` to the decorated object's docstring.

    A missing docstring (``None``, e.g. a function without one or code run under ``python -OO``) is treated as empty,
    consistent with ``add_start_docstrings``; previously this raised ``TypeError`` at decoration time.
    """

    def docstring_decorator(fn):
        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr)
        return fn

    return docstring_decorator
|
|
|
|
|
|
# `str.format` templates opening the ``Returns:`` section built by `_prepare_output_docstrings`; placeholders are
# `full_output_type` (module-qualified output class) and `config_class`.
PT_RETURN_INTRODUCTION = r"""
    Returns:
        :class:`~{full_output_type}` or :obj:`tuple(torch.FloatTensor)`: A :class:`~{full_output_type}` (if
        ``return_dict=True`` is passed or when ``config.return_dict=True``) or a tuple of :obj:`torch.FloatTensor`
        comprising various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs.

"""


TF_RETURN_INTRODUCTION = r"""
    Returns:
        :class:`~{full_output_type}` or :obj:`tuple(tf.Tensor)`: A :class:`~{full_output_type}` (if
        ``return_dict=True`` is passed or when ``config.return_dict=True``) or a tuple of :obj:`tf.Tensor` comprising
        various elements depending on the configuration (:class:`~transformers.{config_class}`) and inputs.

"""


def _get_indent(t):
    """
    Returns the indentation in the first line of t.

    This is the whitespace before the first non-whitespace character (leading blank lines count as part of it), or
    ``""`` when ``t`` contains only whitespace.
    """
    search = re.search(r"^(\s*)\S", t)
    return "" if search is None else search.groups()[0]


def _convert_output_args_doc(output_args_doc):
    """
    Convert the argument list of an output class docstring into a bulleted list for the ``Returns:`` section.

    A line at the same indentation as the first line starts a new argument; more-indented lines are its description
    and lose their first two characters. Each argument block is rendered as
    ``<indent>- **name** <rest of line> -- <description>``.
    """
    # Split output_args_doc in blocks argument/description
    indent = _get_indent(output_args_doc)
    blocks = []
    current_block = ""
    for line in output_args_doc.split("\n"):
        # If the indent is the same as the beginning, the line is the name of new arg.
        if _get_indent(line) == indent:
            if current_block:
                blocks.append(current_block[:-1])
            current_block = f"{line}\n"
        else:
            # Otherwise it's part of the description of the current arg.
            # We need to remove 2 spaces from the indentation.
            current_block += f"{line[2:]}\n"
    blocks.append(current_block[:-1])

    # Format each block for proper rendering
    formatted_blocks = []
    for block in blocks:
        # Bold the argument name (first word after the leading indentation).
        block = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", block)
        # Join the "name (type):" line with its description using " -- ".
        block = re.sub(r":\s*\n\s*(\S)", r" -- \1", block)
        formatted_blocks.append(block)

    return "\n".join(formatted_blocks)


def _prepare_output_docstrings(output_type, config_class):
    """
    Prepares the return part of the docstring using `output_type`.

    Args:
        output_type: class whose docstring ``Args:`` (or ``Parameters:``) section documents the returned fields.
        config_class: name of the configuration class, interpolated into the return introduction.

    Returns:
        :obj:`str`: the ``Returns:`` introduction (TensorFlow flavor when the class name starts with ``"TF"``)
        followed by the formatted field list. If no ``Args:``/``Parameters:`` header is found, the docstring is
        appended unchanged.
    """
    # Docstrings may be missing (no docstring, or stripped by `python -OO`); treat that as empty instead of crashing
    # when the decorated model class is created.
    docstrings = output_type.__doc__ or ""

    # Remove the head of the docstring to keep the list of args only
    lines = docstrings.split("\n")
    args_header = next(
        (i for i, line in enumerate(lines) if re.search(r"^\s*(Args|Parameters):\s*$", line) is not None), None
    )
    if args_header is not None:
        docstrings = _convert_output_args_doc("\n".join(lines[args_header + 1 :]))

    # Add the return introduction
    full_output_type = f"{output_type.__module__}.{output_type.__name__}"
    intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
    intro = intro.format(full_output_type=full_output_type, config_class=config_class)
    return intro + docstrings
|
|
|
|
|
|
# Usage-example templates appended to model docstrings by `add_code_sample_docstrings`. They are filled with
# `str.format` using `tokenizer_class`, `model_class` and `checkpoint` (plus `mask` for the masked-LM templates), so
# literal braces inside the examples are escaped as `{{`/`}}`. PT_* templates are used for PyTorch classes and TF_*
# templates for classes whose name starts with "TF".
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0)  # Batch size 1

        >>> outputs = model(**inputs, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
"""

PT_QUESTION_ANSWERING_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        >>> inputs = tokenizer(question, text, return_tensors='pt')
        >>> start_positions = torch.tensor([1])
        >>> end_positions = torch.tensor([3])

        >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
        >>> loss = outputs.loss
        >>> start_scores = outputs.start_logits
        >>> end_scores = outputs.end_logits
"""

PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        >>> outputs = model(**inputs, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
"""

PT_MASKED_LM_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]

        >>> outputs = model(**inputs, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
"""

PT_BASE_MODEL_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
"""

PT_MULTIPLE_CHOICE_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import torch

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> choice0 = "It is eaten with a fork and a knife."
        >>> choice1 = "It is eaten while held in the hand."
        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)
        >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels)  # batch size is 1

        >>> # the linear classifier still needs to be trained
        >>> loss = outputs.loss
        >>> logits = outputs.logits
"""

PT_CAUSAL_LM_SAMPLE = r"""
    Example::

        >>> import torch
        >>> from transformers import {tokenizer_class}, {model_class}

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs, labels=inputs["input_ids"])
        >>> loss = outputs.loss
        >>> logits = outputs.logits
"""

TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> input_ids = inputs["input_ids"]
        >>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1

        >>> outputs = model(inputs)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
"""

TF_QUESTION_ANSWERING_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        >>> input_dict = tokenizer(question, text, return_tensors='tf')
        >>> outputs = model(input_dict)
        >>> start_logits = outputs.start_logits
        >>> end_logits = outputs.end_logits

        >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
        >>> answer = ' '.join(all_tokens[tf.math.argmax(start_logits, 1)[0] : tf.math.argmax(end_logits, 1)[0]+1])
"""

TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1))  # Batch size 1

        >>> outputs = model(inputs)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
"""

TF_MASKED_LM_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf")
        >>> inputs["labels"] = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]

        >>> outputs = model(inputs)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
"""

TF_BASE_MODEL_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> outputs = model(inputs)

        >>> last_hidden_states = outputs.last_hidden_state
"""

TF_MULTIPLE_CHOICE_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> choice0 = "It is eaten with a fork and a knife."
        >>> choice1 = "It is eaten while held in the hand."

        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)
        >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
        >>> outputs = model(inputs)  # batch size is 1

        >>> # the linear classifier still needs to be trained
        >>> logits = outputs.logits
"""

TF_CAUSAL_LM_SAMPLE = r"""
    Example::

        >>> from transformers import {tokenizer_class}, {model_class}
        >>> import tensorflow as tf

        >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
        >>> model = {model_class}.from_pretrained('{checkpoint}')

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
        >>> outputs = model(inputs)
        >>> logits = outputs.logits
"""
|
|
|
|
|
|
def add_code_sample_docstrings(
    *docstr, tokenizer_class=None, checkpoint=None, output_type=None, config_class=None, mask=None
):
    """
    Build a decorator that appends documentation and a formatted usage example to a model method's docstring.

    The example template is chosen from the name of the class owning the decorated method. The TensorFlow flavor is
    used when that name starts with ``"TF"``; otherwise the PyTorch flavor is used.

    Args:
        docstr (:obj:`str`):
            Extra docstring fragments, concatenated and appended right after the existing docstring.
        tokenizer_class (:obj:`str`, `optional`):
            Tokenizer class name substituted into the example.
        checkpoint (:obj:`str`, `optional`):
            Checkpoint identifier substituted into the example.
        output_type (`optional`):
            When given, return-value documentation built from it (and ``config_class``) is inserted before the
            example.
        config_class (:obj:`str`, `optional`):
            Configuration class name used when documenting ``output_type``.
        mask (:obj:`str`, `optional`):
            Mask token used by masked-LM examples, ``"[MASK]"`` when not provided.

    Raises:
        :obj:`ValueError`: (from the decorator) when no example template matches the owning class name.
    """

    def docstring_decorator(fn):
        # "Foo.forward" -> "Foo": the first component of the qualified name is the owning class.
        owner = fn.__qualname__.split(".")[0]
        use_tf = owner.startswith("TF")
        fmt_args = {"model_class": owner, "tokenizer_class": tokenizer_class, "checkpoint": checkpoint}

        # Check order matters: e.g. "XLMWithLMHeadModel" contains "LMHead" and "Model", so the specific tasks are
        # tested before the generic "Model"/"Encoder" fallback.
        if "SequenceClassification" in owner:
            template = TF_SEQUENCE_CLASSIFICATION_SAMPLE if use_tf else PT_SEQUENCE_CLASSIFICATION_SAMPLE
        elif "QuestionAnswering" in owner:
            template = TF_QUESTION_ANSWERING_SAMPLE if use_tf else PT_QUESTION_ANSWERING_SAMPLE
        elif "TokenClassification" in owner:
            template = TF_TOKEN_CLASSIFICATION_SAMPLE if use_tf else PT_TOKEN_CLASSIFICATION_SAMPLE
        elif "MultipleChoice" in owner:
            template = TF_MULTIPLE_CHOICE_SAMPLE if use_tf else PT_MULTIPLE_CHOICE_SAMPLE
        elif "MaskedLM" in owner or owner in ("FlaubertWithLMHeadModel", "XLMWithLMHeadModel"):
            fmt_args["mask"] = mask if mask is not None else "[MASK]"
            template = TF_MASKED_LM_SAMPLE if use_tf else PT_MASKED_LM_SAMPLE
        elif "LMHead" in owner or "CausalLM" in owner:
            template = TF_CAUSAL_LM_SAMPLE if use_tf else PT_CAUSAL_LM_SAMPLE
        elif "Model" in owner or "Encoder" in owner:
            template = TF_BASE_MODEL_SAMPLE if use_tf else PT_BASE_MODEL_SAMPLE
        else:
            raise ValueError(f"Docstring can't be built for model {owner}")

        returns_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
        example = template.format(**fmt_args)
        fn.__doc__ = "".join([fn.__doc__ or "", *docstr, returns_doc, example])
        return fn

    return docstring_decorator
|
|
|
|
|
|
def replace_return_docstrings(output_type=None, config_class=None):
    """
    Build a decorator that swaps the placeholder ``Return:``/``Returns:`` line of a docstring for the return-value
    documentation generated from ``output_type`` and ``config_class``.

    Args:
        output_type (`optional`):
            Output class whose documentation replaces the placeholder line.
        config_class (:obj:`str`, `optional`):
            Configuration class name used when documenting ``output_type``.

    Raises:
        :obj:`ValueError`: (from the decorator) when the decorated function has no docstring, or no line of its
        docstring consists only of ``Return:`` or ``Returns:`` (surrounding whitespace allowed).
    """

    def docstring_decorator(fn):
        docstrings = fn.__doc__
        # A missing docstring used to fail with an opaque AttributeError on ``None.split``; treat it as a docstring
        # without placeholder so the caller gets the explanatory ValueError below instead.
        lines = docstrings.split("\n") if docstrings is not None else []
        placeholder_idx = next(
            (i for i, line in enumerate(lines) if re.search(r"^\s*Returns?:\s*$", line)),
            None,
        )
        if placeholder_idx is None:
            raise ValueError(
                f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, current docstring is:\n{docstrings}"
            )
        lines[placeholder_idx] = _prepare_output_docstrings(output_type, config_class)
        fn.__doc__ = "\n".join(lines)
        return fn

    return docstring_decorator
|
|
|
|
|
|
def is_remote_url(url_or_filename):
    """
    Return :obj:`True` when ``url_or_filename`` uses the ``http`` or ``https`` scheme, :obj:`False` otherwise
    (local paths, other schemes such as ``s3://``).
    """
    scheme = urlparse(url_or_filename).scheme
    return scheme == "http" or scheme == "https"
|
|
|
|
|
|
def hf_bucket_url(
    model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
    """
    Resolve a model identifier, a file name, an optional subfolder and an optional revision to the URL of that file
    on huggingface.co, or on a mirror when ``mirror`` is set.

    Args:
        model_id (:obj:`str`):
            Model identifier, either namespaced (``"user/model"``) or legacy without namespace.
        filename (:obj:`str`):
            Name of the file inside the model repository.
        subfolder (:obj:`str`, `optional`):
            Folder of the repository containing ``filename``.
        revision (:obj:`str`, `optional`):
            Branch, tag or commit id; ``"main"`` when not given. Not used when ``mirror`` is set.
        mirror (:obj:`str`, `optional`):
            Either a key of ``PRESET_MIRROR_DICT`` or a mirror endpoint used verbatim.

    Files are stored on huggingface.co in a content-addressable way (git-based versioning), so cached copies served
    through Cloudfront cannot go stale. Client-side caching in this library is keyed on the objects' ETag: the sha1 for
    files stored in git, the sha256 for files stored in git-lfs. Files cached locally by transformers before v3.5.0
    are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
    """
    path_in_repo = filename if subfolder is None else f"{subfolder}/{filename}"

    if not mirror:
        return HUGGINGFACE_CO_PREFIX.format(
            model_id=model_id, revision="main" if revision is None else revision, filename=path_in_repo
        )

    endpoint = PRESET_MIRROR_DICT.get(mirror, mirror)
    # Legacy (namespace-less) model ids are joined to the file name with "-" instead of "/".
    separator = "-" if "/" not in model_id else "/"
    return f"{endpoint}/{model_id}{separator}{path_in_repo}"
|
|
|
|
|
|
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """
    Deterministically map ``url`` (and optionally ``etag``) to a cache file name.

    The name is the sha256 hex digest of the url, followed by ``.`` and the sha256 hex digest of the etag when an
    etag is given. URLs ending in ``.h5`` (Keras HDF5 weights) keep a trailing ``.h5`` so TF 2.0 recognizes the file as
    HDF5 (see
    https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
    """
    parts = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        parts.append(sha256(etag.encode("utf-8")).hexdigest())
    if url.endswith(".h5"):
        parts.append("h5")
    return ".".join(parts)
|
|
|
|
|
|
def filename_to_url(filename, cache_dir=None):
    """
    Read back the ``(url, etag)`` recorded for a cached file.

    The metadata lives in a JSON sidecar ``<filename>.json`` next to the cached file in ``cache_dir`` (the transformers
    cache when not given). The etag may be ``None``.

    Raises:
        ``EnvironmentError``: if the cached file or its metadata sidecar is missing.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    meta_path = cache_path + ".json"
    # The cached file is checked first, then its sidecar.
    for required_path in (cache_path, meta_path):
        if not os.path.exists(required_path):
            raise EnvironmentError("file {} not found".format(required_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    return metadata["url"], metadata["etag"]
|
|
|
|
|
|
def get_cached_models(cache_dir: Optional[Union[str, Path]] = None) -> List[Tuple]:
    """
    Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape
    :obj:`(model_url, etag, size_MB)`. The ``.json`` metadata files in :obj:`cache_dir` are read to get the metadata
    for each model; only urls ending with `.bin` are added.

    Args:
        cache_dir (:obj:`Union[str, Path]`, `optional`):
            The cache directory to search for models within. Will default to the transformers cache if unset.

    Returns:
        List[Tuple]: List of tuples each with shape :obj:`(model_url, etag, size_MB)`, where ``size_MB`` is the size of
        the cached file in megabytes (10**6 bytes).
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    elif isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    suffix = ".json"
    cached_models = []
    for file in os.listdir(cache_dir):
        if not file.endswith(suffix):
            continue
        meta_path = os.path.join(cache_dir, file)
        with open(meta_path, encoding="utf-8") as meta_file:
            metadata = json.load(meta_file)
        url = metadata["url"]
        etag = metadata["etag"]
        if url.endswith(".bin"):
            # The cached file sits next to its metadata, with the ".json" suffix removed. Slicing is used on purpose:
            # ``str.strip(".json")`` removes any of the characters ".jnos" from BOTH ends, which corrupts file names
            # ending in one of those letters and relative directories such as "./cache".
            blob_path = meta_path[: -len(suffix)]
            size_MB = os.path.getsize(blob_path) / 1e6
            cached_models.append((url, etag, size_MB))

    return cached_models
|
|
|
|
|
|
def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    extract_compressed_file=False,
    force_extract=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file
    and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and
    then return the path.

    Args:
        cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
        force_download: if True, re-download the file even if it's already cached in the cache dir.
        proxies: proxies passed through to the HTTP requests.
        resume_download: if True, resume the download if incompletely received file is found.
        user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
        use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
            will get token from ~/.huggingface.
        extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
            file in a folder along the archive.
        force_extract: if True when extract_compressed_file is True and the archive was already extracted,
            re-extract the archive and override the folder where it was extracted.
        local_files_only: if True, never query the network and only look for the file in the cache.

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk. When
        ``extract_compressed_file`` is set and the file is a zip or tar archive, the path of the extraction folder.

    Raises:
        ``EnvironmentError``: for a local path that does not exist, or an archive whose format cannot be identified.
        ``ValueError``: when ``url_or_filename`` is neither a local path nor a supported URL.
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if not extract_compressed_file:
        return output_path

    if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
        return output_path

    # Path where we extract compressed archives
    # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
    output_dir, output_file = os.path.split(output_path)
    output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
    output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

    # A non-empty extraction folder is reused unless re-extraction is forced.
    if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
        return output_path_extracted

    # Prevent parallel extractions
    lock_path = output_path + ".lock"
    with FileLock(lock_path):
        shutil.rmtree(output_path_extracted, ignore_errors=True)
        os.makedirs(output_path_extracted)
        if is_zipfile(output_path):
            with ZipFile(output_path, "r") as zip_file:
                zip_file.extractall(output_path_extracted)
        elif tarfile.is_tarfile(output_path):
            # NOTE(security): tar members are extracted as-is, so a crafted archive can write outside
            # ``output_path_extracted`` through absolute paths or "..". Only extract archives from trusted sources.
            # The context manager closes the archive even if extraction fails part-way.
            with tarfile.open(output_path) as tar_file:
                tar_file.extractall(output_path_extracted)
        else:
            raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

    return output_path_extracted
|
|
|
|
|
|
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Build the user-agent string sent with requests.

    It lists the transformers and Python versions and, when available, the PyTorch and TensorFlow versions. A
    caller-provided ``user_agent`` follows: either a dict rendered as ``key/value`` entries, or a raw string. All
    entries are separated by ``"; "``.
    """
    parts = ["transformers/{}".format(__version__), "python/{}".format(sys.version.split()[0])]
    if is_torch_available():
        parts.append(f"torch/{_torch_version}")
    if is_tf_available():
        parts.append(f"tensorflow/{_tf_version}")
    if isinstance(user_agent, dict):
        # Appended as a single element so that an empty dict still yields the trailing "; " separator.
        parts.append("; ".join("{}/{}".format(k, v) for k, v in user_agent.items()))
    elif isinstance(user_agent, str):
        parts.append(user_agent)
    return "; ".join(parts)
|
|
|
|
|
|
def http_get(url: str, temp_file: BinaryIO, proxies=None, resume_size=0, headers: Optional[Dict[str, str]] = None):
    """
    Download a remote file and write its content to ``temp_file`` chunk by chunk. Do not gobble up errors.

    Args:
        url: URL to download.
        temp_file: binary file-like object receiving the downloaded bytes.
        proxies: passed through to ``requests.get``.
        resume_size: number of bytes already downloaded. When positive, a ``Range`` header requests only the
            remaining bytes, and the progress bar starts at this offset.
        headers: extra request headers. The caller's dict is never modified.

    Raises:
        ``requests.HTTPError``: if the server answers with an error status.
    """
    # Copy so that adding "Range" never mutates the caller's dict. A missing dict becomes an empty one, so resuming
    # without explicit headers no longer fails with ``TypeError`` on ``None["Range"]``.
    headers = copy.deepcopy(headers) if headers is not None else {}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    r = requests.get(url, stream=True, proxies=proxies, headers=headers)
    try:
        r.raise_for_status()
        content_length = r.headers.get("Content-Length")
        total = resume_size + int(content_length) if content_length is not None else None
        progress = tqdm(
            unit="B",
            unit_scale=True,
            total=total,
            initial=resume_size,
            desc="Downloading",
            disable=bool(logging.get_verbosity() == logging.NOTSET),
        )
        try:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            progress.close()
    finally:
        # A streamed response holds its connection until it is closed, including when the download fails midway.
        r.close()
|
|
|
|
|
|
def get_from_cache(
    url: str,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent: Union[Dict, str, None] = None,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
) -> Optional[str]:
    """
    Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
    path to the cached file.

    The cache file name is derived from the URL and the remote ETag, obtained with a HEAD request. When that request
    fails with a connection error or a timeout, or when ``local_files_only`` is set, the most recently modified cached
    version of the URL is returned instead.

    Args:
        url: remote URL of the file.
        cache_dir: cache directory (the transformers cache when not given); created if missing.
        force_download: re-download even if a file with the current ETag is already cached.
        proxies: proxies passed through to the HTTP requests.
        etag_timeout: timeout, in seconds, of the HEAD request used to fetch the ETag.
        resume_download: keep partial downloads in ``<cache file>.incomplete`` and resume from them.
        user_agent: string or dict appended to the user-agent header.
        use_auth_token: Bearer token string, or True to read the token stored by ``HfFolder``.
        local_files_only: never query the network, only look into the cache.

    Return:
        Local path (string) of file or if networking is off, last version of file cached on disk.

    Raises:
        ``EnvironmentError``: if ``use_auth_token=True`` but no token is stored.
        ``OSError``: if the remote resource has no ETag.
        ``FileNotFoundError`` / ``ValueError``: if the file cannot be fetched and no cached copy exists
        (respectively with ``local_files_only`` set or after a connection error).
    """
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    headers = {"user-agent": http_user_agent(user_agent)}
    if isinstance(use_auth_token, str):
        headers["authorization"] = "Bearer {}".format(use_auth_token)
    elif use_auth_token:
        token = HfFolder.get_token()
        if token is None:
            raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
        headers["authorization"] = "Bearer {}".format(token)

    url_to_download = url
    etag = None
    if not local_files_only:
        try:
            r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
            r.raise_for_status()
            etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
            # We favor a custom header indicating the etag of the linked resource, and
            # we fallback to the regular etag header.
            # If we don't have any of those, raise an error.
            if etag is None:
                raise OSError(
                    "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
                )
            # In case of a redirect,
            # save an extra redirect on the request.get call,
            # and ensure we download the exact atomic version even if it changed
            # between the HEAD and the GET (unlikely, but hey).
            if 300 <= r.status_code <= 399:
                url_to_download = r.headers["Location"]
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
            # etag is already None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None == we don't have a connection or we passed local_files_only.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        # Any cached version of this URL starts with the URL hash. Metadata sidecars (".json"), lock files (".lock")
        # and interrupted resumable downloads (".incomplete") are not usable cache entries.
        matching_files = [
            file
            for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
            if not file.endswith((".json", ".lock", ".incomplete"))
        ]
        if matching_files:
            # os.listdir order is arbitrary: explicitly pick the most recently written version.
            latest = max(matching_files, key=lambda f: os.path.getmtime(os.path.join(cache_dir, f)))
            return os.path.join(cache_dir, latest)
        # If files cannot be found and local_files_only=True,
        # the models might've been found if local_files_only=False
        # Notify the user about that
        if local_files_only:
            raise FileNotFoundError(
                "Cannot find the requested files in the cached path and outgoing traffic has been"
                " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                " to False."
            )
        raise ValueError(
            "Connection error, and we cannot find the requested files in the cached path."
            " Please try again or make sure your Internet connection is on."
        )

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager() -> "io.BufferedWriter":
                with open(incomplete_path, "ab") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

            http_get(url_to_download, temp_file, proxies=proxies, resume_size=resume_size, headers=headers)

        logger.info("storing %s in cache at %s", url, cache_path)
        os.replace(temp_file.name, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        # Explicit encoding, matching the readers of this sidecar that open it with encoding="utf-8".
        with open(meta_path, "w", encoding="utf-8") as meta_file:
            json.dump(meta, meta_file)

    return cache_path
|
|
|
|
|
|
class cached_property(property):
    """
    Property-like descriptor that computes its value once per instance and then reuses it.

    The value is stored on the instance under the attribute ``"__cached_<getter name>"``. A getter result of ``None``
    is not considered cached, so it is recomputed on the next access.

    From tensorflow_datasets

    Built-in in functools from Python 3.8.
    """

    def __get__(self, obj, objtype=None):
        # See docs.python.org/3/howto/descriptor.html#properties
        # Accessed on the class itself: return the descriptor, like a plain property does.
        if obj is None:
            return self
        getter = self.fget
        if getter is None:
            raise AttributeError("unreadable attribute")
        slot = f"__cached_{getter.__name__}"
        value = getattr(obj, slot, None)
        if value is not None:
            return value
        value = getter(obj)
        setattr(obj, slot, value)
        return value
|
|
|
|
|
|
def torch_required(func):
    """
    Decorator making ``func`` raise :obj:`ImportError` when called while PyTorch is not available.

    Availability is checked at call time, not at decoration time.
    """
    # Chose a different decorator name than in tests so it's clear they are not the same.

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not is_torch_available():
            raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def tf_required(func):
    """
    Decorator making ``func`` raise :obj:`ImportError` when called while TensorFlow is not available.

    Availability is checked at call time, not at decoration time.
    """
    # Chose a different decorator name than in tests so it's clear they are not the same.

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not is_tf_available():
            raise ImportError(f"Method `{func.__name__}` requires TF.")
        return func(*args, **kwargs)

    return wrapper
|
|
|
|
|
|
def is_tensor(x):
    """
    Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor` or :obj:`np.ndarray`.

    A framework is only imported when it is reported as available, and PyTorch is tried before TensorFlow; the first
    match short-circuits the remaining checks.
    """

    def _torch_tensor_type():
        import torch

        return torch.Tensor

    def _tf_tensor_type():
        import tensorflow as tf

        return tf.Tensor

    framework_checks = ((is_torch_available, _torch_tensor_type), (is_tf_available, _tf_tensor_type))
    for available, load_tensor_type in framework_checks:
        if available() and isinstance(x, load_tensor_type()):
            return True
    return isinstance(x, np.ndarray)
|
|
|
|
|
|
class ModelOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like
    a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes. Otherwise behaves like a regular
    python dictionary.

    Only fields whose value is not ``None`` become dictionary keys. ``__delitem__``, ``setdefault``, ``pop`` and
    ``update`` are disabled.

    .. warning::
        You can't unpack a :obj:`ModelOutput` directly. Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple`
        method to convert it to a tuple before.
    """

    def __post_init__(self):
        class_fields = fields(self)

        # Safety and consistency checks
        assert len(class_fields), f"{self.__class__.__name__} has no fields."
        assert all(
            field.default is None for field in class_fields[1:]
        ), f"{self.__class__.__name__} should not have more than one required field."

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and not is_tensor(first_field):
            try:
                iterator = iter(first_field)
                first_field_iterator = True
            except TypeError:
                first_field_iterator = False

            # if we provided an iterator as first field and the iterator is a (key, value) iterator
            # set the associated fields
            if first_field_iterator:
                for idx, element in enumerate(iterator):
                    if (
                        not isinstance(element, (list, tuple))
                        or not len(element) == 2
                        or not isinstance(element[0], str)
                    ):
                        if idx == 0:
                            # Not a (key, value) iterator at all (e.g. a tuple of tensors or a string): the iterable
                            # itself is the value of the first field. It used to be silently dropped.
                            self[class_fields[0].name] = first_field
                        else:
                            # (key, value) pairs followed by something else: refuse rather than keep a partial output.
                            raise ValueError(
                                f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
                            )
                        break
                    setattr(self, element[0], element[1])
                    if element[1] is not None:
                        self[element[0]] = element[1]
            elif first_field is not None:
                self[class_fields[0].name] = first_field
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        if isinstance(k, str):
            # Direct mapping lookup (raises KeyError for missing/None fields). The previous version copied the whole
            # mapping on every string lookup, which made ``to_tuple`` quadratic in the number of keys.
            return super().__getitem__(k)
        return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any]:
        """
        Convert self to a tuple containing all the attributes/keys that are not ``None``.
        """
        return tuple(self[k] for k in self.keys())
|
|
|
|
|
|
class _BaseLazyModule(ModuleType):
|
|
"""
|
|
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
|
"""
|
|
|
|
# Very heavily inspired by optuna.integration._IntegrationModule
|
|
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
|
def __init__(self, name, import_structure):
|
|
super().__init__(name)
|
|
self._modules = set(import_structure.keys())
|
|
self._class_to_module = {}
|
|
for key, values in import_structure.items():
|
|
for value in values:
|
|
self._class_to_module[value] = key
|
|
# Needed for autocompletion in an IDE
|
|
self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])
|
|
|
|
# Needed for autocompletion in an IDE
|
|
def __dir__(self):
|
|
return super().__dir__() + self.__all__
|
|
|
|
def __getattr__(self, name: str) -> Any:
|
|
if name in self._modules:
|
|
value = self._get_module(name)
|
|
elif name in self._class_to_module.keys():
|
|
module = self._get_module(self._class_to_module[name])
|
|
value = getattr(module, name)
|
|
else:
|
|
raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
|
|
|
setattr(self, name, value)
|
|
return value
|
|
|
|
def _get_module(self, module_name: str) -> ModuleType:
|
|
raise NotImplementedError
|