Integrate Bert-like model on Flax runtime. (#3722)
* WIP flax bert
* Initial commit Bert Jax/Flax implementation.
* Embeddings working and equivalent to PyTorch.
* Move embeddings in its own module BertEmbeddings
* Added jax.jit annotation on forward call
* BertEncoder on par with PyTorch ! :D
* Add BertPooler on par with PyTorch !!
* Working Jax+Flax implementation of BertModel with < 1e-5 differences on the last layer.
* Fix pooled output to take only the first token of the sequence.
* Refactoring to use BertConfig from transformers.
* Renamed FXBertModel to FlaxBertModel
* Model is now initialized in FlaxBertModel constructor and reused.
* WIP JaxPreTrainedModel
* Cleaning up the code of FlaxBertModel
* Added ability to load Flax model saved through save_pretrained()
* Added ability to convert Pytorch Bert model to FlaxBert
* FlaxBert can now load every Pytorch Bert model with on-the-fly conversion
* Fix hardcoded shape values in conversion scripts.
* Improve the way we handle LayerNorm conversion from PyTorch to Flax.
* Added positional embeddings as parameter of BertModel with default to np.arange.
* Let's roll FlaxRoberta !
* Fix missing position_ids parameters on predict for Bert
* Flax backend now supports batched inputs
* Make it possible to load msgpacked model on convert from pytorch in last resort.
* Moved save_pretrained to Jax base class along with more constructor parameters.
* Use specialized, model dependent conversion function.
* Expose `is_flax_available` in file_utils.
* Added unittest for Flax models.
* Added run_tests_flax to the CI.
* Introduce FlaxAutoModel
* Added more unittests
* Flax model reference the _MODEL_ARCHIVE_MAP from PyTorch model.
* Addressing review comments.
* Expose seed in both Bert and Roberta
* Fix typo suggested by @stefan-it
* Attempt to make style
* Attempt to make style in tests too
* Added jax & jaxlib to the flax optional dependencies.
* Attempt to fix flake8 warnings ...
* Redo black again and again
* When black and flake8 fight each other for a space ... 💥 💥 💥
* Try removing trailing comma to make both black and flake happy!
* Fix invalid is_<framework>_available call, thanks @LysandreJik 🎉
* Fix another invalid import in flax_roberta test
* Bump and pin flax release to 0.1.0.
* Make flake8 happy, remove unused jax import
* Change the type of the catch for msgpack.
* Remove unused import.
* Put seed as optional constructor parameter.
* trigger ci again
* Fix too much parameters in BertAttention.
* Formatting.
* Simplify Flax unittests to avoid machine crashes.
* Fix invalid number of arguments when raising issue for an unknown model.
* Address @bastings comment in PR, moving jax.jit decorated outside of __call__
* Fix incorrect path to require_flax/require_pytorch functions.
* Attempt to make style.
* Correct rebasing of circle-ci dependencies
* Fix import sorting.
* Fix unused imports.
* Again import sorting...
* Installing missing nlp dependency for flax unittests.
* Fix loading of model for Flax implementations.
* jit the inner function call to make JAX-compatible
* Format !
* Flake one more time 🎶
* Rewrites BERT in Flax to the new Linen API (#7211)
* Rewrite Flax HuggingFace PR to Linen
* Some fixes
* Fix tests
* Fix CI with change of name of nlp (#7054)
* nlp -> datasets
* More nlp -> datasets
* Woopsie
* More nlp -> datasets
* One last
* Expose `is_flax_available` in file_utils.
* Added run_tests_flax to the CI.
* Attempt to make style
* trigger ci again
* Fix import sorting.
* Revert "Rewrites BERT in Flax to the new Linen API (#7211)"
  This reverts commit 23703a5eb3364e26a1cbc3ee34b4710d86a674b0.
* Remove jnp.lax references
* Make style.
* Reintroduce Linen changes ...
* Make style.
* Use jax native's gelu function.
* Renaming BertModel to BertModule to highlight the fact this is the Flax Module object.
* Rewrite FlaxAutoModel test to not rely on pretrained_model_archive_map
* Remove unused variable in BertModule.
* Remove unused variable in BertModule again
* Attempt to have is_flax_available working again.
* Introduce JAX TensorType
* Improve ImportError message when trying to convert to various TensorType format.
* Makes Flax model jittable.
* Ensure flax models are jittable in unittests.
* Remove unused imports.
* Ensure jax imports are guarded behind is_flax_available.
* Make style.
* Make style again
* Make style again again
* Make style again again again
* Update src/transformers/file_utils.py
* Bump flax to its latest version
* Bump jax version to at least 0.2.0
* Style.
* Update the unittest to use TensorType.JAX
* isort import in tests.
* Match new flax parameters name "params"
* Remove unused imports.
* Add flax models to transformers __init__
* Attempt to address all CI related comments.
* Correct circle.yml indent.
* Correct circle.yml indent (2)
* Remove coverage from flax tests
* Addressing many naming suggestions from comments
* Simplify for loop logic to iterate over layers in FlaxBertLayerCollection
* use f-string syntax for formatting logs.
* Use config property from FlaxPreTrainedModel.
* use "cls_token" instead of "first_token" variable name.
* use "hidden_state" instead of "h" variable name.
* Correct class reference in docstring to link to Flax related modules.
* Added HF + Google Flax team copyright.
* Make Roberta independent from Bert
* Move activation functions to flax_utils.
* Move activation functions to flax_utils for bert.
* Added docstring for BERT
* Update import for Bert and Roberta tokenizers
* Make style.
* fix-copies
* Correct FlaxRobertaLayer to match PyTorch.
* Use the same store_artifact for flax unittest
* Style.
* Make sure gradient are disabled only locally for flax unittest using torch equivalence.
* Use relative imports

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
Co-authored-by: Stefan Schweter <stefan@schweter.it>
Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@@ -139,6 +139,31 @@ jobs:
             - store_artifacts:
                   path: ~/transformers/output.txt
                   destination: test_output.txt
+    run_tests_flax:
+        working_directory: ~/transformers
+        docker:
+            - image: circleci/python:3.7
+        environment:
+            OMP_NUM_THREADS: 1
+        resource_class: xlarge
+        parallelism: 1
+        steps:
+            - checkout
+            - restore_cache:
+                  keys:
+                      - v0.3-flax-{{ checksum "setup.py" }}
+                      - v0.3-{{ checksum "setup.py" }}
+            - run: pip install --upgrade pip
+            - run: pip install git+https://github.com/huggingface/datasets
+            - run: sudo pip install .[flax,sklearn,torch,testing]
+            - save_cache:
+                  key: v0.3-flax-{{ checksum "setup.py" }}
+                  paths:
+                      - '~/.cache/pip'
+            - run: python -m pytest -n 8 --dist=loadfile -rA -s ./tests/ | tee output.txt
+            - store_artifacts:
+                  path: ~/transformers/output.txt
+                  destination: test_output.txt
     run_tests_custom_tokenizers:
         working_directory: ~/transformers
         docker:
@@ -305,6 +330,7 @@ workflows:
             - run_tests_torch_and_tf
             - run_tests_torch
             - run_tests_tf
+            - run_tests_flax
             - build_doc
             - deploy_doc: *workflow_filters
     tpu_testing_jobs:
1
setup.py
@@ -87,6 +87,7 @@ extras["tf-cpu"] = [
     # "keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx",
 ]
 extras["torch"] = ["torch>=1.0"]
+extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"]
 extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]

 extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
@@ -103,6 +103,7 @@ from .file_utils import (
     is_apex_available,
     is_datasets_available,
     is_faiss_available,
+    is_flax_available,
     is_psutil_available,
     is_py3nvml_available,
     is_sentencepiece_available,
@@ -817,6 +818,10 @@ else:
     from .utils.dummy_tf_objects import *


+if is_flax_available():
+    from .modeling_flax_bert import FlaxBertModel
+    from .modeling_flax_roberta import FlaxRobertaModel
+
 if not is_tf_available() and not is_torch_available():
     logger.warning(
         "Neither PyTorch nor TensorFlow >= 2.0 have been found."
@@ -34,10 +34,13 @@ from .utils import logging

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
+ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
+
 try:
     USE_TF = os.environ.get("USE_TF", "AUTO").upper()
     USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-    if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
+    if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
         import torch

         _torch_available = True  # pylint: disable=invalid-name
@@ -52,7 +55,7 @@ try:
     USE_TF = os.environ.get("USE_TF", "AUTO").upper()
     USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

-    if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
+    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
         import tensorflow as tf

         assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
@@ -65,6 +68,22 @@ except (ImportError, AssertionError):
     _tf_available = False  # pylint: disable=invalid-name


+try:
+    USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+
+    if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+        import flax
+        import jax
+
+        logger.info("JAX version {}, Flax: available".format(jax.__version__))
+        logger.info("Flax available: {}".format(flax))
+        _flax_available = True
+    else:
+        _flax_available = False
+except ImportError:
+    _flax_available = False  # pylint: disable=invalid-name
+
+
 try:
     import datasets  # noqa: F401

@@ -213,6 +232,10 @@ def is_tf_available():
     return _tf_available


+def is_flax_available():
+    return _flax_available
+
+
 def is_torch_tpu_available():
     return _torch_tpu_available

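The availability check added in the hunks above follows a simple pattern: read an environment flag, and only then attempt the imports. A minimal standalone sketch of that pattern is shown here. The helper name `detect_backend` and its `environ` parameter are hypothetical and are not part of `file_utils.py`; the usage lines below import the standard-library `json` module as a stand-in so they run without Flax installed.

```python
import importlib
import os

TRUE_VALUES = {"1", "ON", "YES"}
TRUE_AND_AUTO_VALUES = TRUE_VALUES | {"AUTO"}


def detect_backend(env_var, module_names, environ=None):
    """Return True when the flag allows the backend and every module imports.

    Hypothetical helper for illustration only; the commit itself sets
    module-level flags such as `_flax_available` instead.
    """
    environ = os.environ if environ is None else environ
    flag = environ.get(env_var, "AUTO").upper()
    if flag not in TRUE_AND_AUTO_VALUES:
        return False  # explicitly disabled, e.g. USE_FLAX=0
    try:
        for name in module_names:
            importlib.import_module(name)
    except ImportError:
        return False  # at least one required package is missing
    return True


# "json" stands in for an installed backend; the second name does not exist.
print(detect_backend("USE_FLAX", ["json"], environ={}))
print(detect_backend("USE_FLAX", ["json"], environ={"USE_FLAX": "0"}))
print(detect_backend("USE_FLAX", ["no_such_backend_module"], environ={}))
```

Defaulting to `AUTO` means an installed backend is picked up without any configuration, while any value outside the accepted set switches it off.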
167
src/transformers/modeling_flax_auto.py
Normal file
@@ -0,0 +1,167 @@
+# coding=utf-8
+# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Auto Model class. """
+
+
+from collections import OrderedDict
+
+from .configuration_auto import AutoConfig, BertConfig, RobertaConfig
+from .configuration_utils import PretrainedConfig
+from .modeling_flax_bert import FlaxBertModel
+from .modeling_flax_roberta import FlaxRobertaModel
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
+    (key, value)
+    for pretrained_map in [
+        FlaxBertModel.pretrained_model_archive_map,
+        FlaxRobertaModel.pretrained_model_archive_map,
+    ]
+    for key, value, in pretrained_map.items()
+)
+
+MODEL_MAPPING = OrderedDict(
+    [
+        (RobertaConfig, FlaxRobertaModel),
+        (BertConfig, FlaxBertModel),
+    ]
+)
+
+
+class FlaxAutoModel(object):
+    r"""
+    :class:`~transformers.FlaxAutoModel` is a generic model class
+    that will be instantiated as one of the base model classes of the library
+    when created with the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)`
+    or the `FlaxAutoModel.from_config(config)` class methods.
+
+    This class cannot be instantiated using `__init__()` (throws an error).
+    """
+
+    def __init__(self):
+        raise EnvironmentError(
+            "FlaxAutoModel is designed to be instantiated "
+            "using the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or "
+            "`FlaxAutoModel.from_config(config)` methods."
+        )
+
+    @classmethod
+    def from_config(cls, config):
+        r"""Instantiates one of the base model classes of the library
+        from a configuration.
+
+        Args:
+            config (:class:`~transformers.PretrainedConfig`):
+                The model class to instantiate is selected based on the configuration class:
+
+                - isInstance of `roberta` configuration class: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
+                - isInstance of `bert` configuration class: :class:`~transformers.FlaxBertModel` (Bert model)
+        Examples::
+
+            config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
+            model = FlaxAutoModel.from_config(config)  # Instantiate the model class matching the configuration.
+        """
+        for config_class, model_class in MODEL_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class(config)
+        raise ValueError(
+            f"Unrecognized configuration class {config.__class__} "
+            f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
+            f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}."
+        )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        r"""Instantiates one of the base model classes of the library
+        from a pre-trained model configuration.
+
+        The `from_pretrained()` method takes care of returning the correct model class instance
+        based on the `model_type` property of the config object, or when it's missing,
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string.
+
+        The base model class to instantiate is selected as the first pattern matching
+        in the `pretrained_model_name_or_path` string (in the following order):
+            - contains `roberta`: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
+            - contains `bert`: :class:`~transformers.FlaxBertModel` (Bert model)
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
+        To train the model, you should first set it back in training mode with `model.train()`
+
+        Args:
+            pretrained_model_name_or_path: either:
+
+                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
+                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
+                - a path to a `directory` containing model weights saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
+                - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+            model_args: (`optional`) Sequence of positional arguments:
+                All remaining positional arguments will be passed to the underlying model's ``__init__`` method
+
+            config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
+                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
+
+                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
+                - the model was saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
+                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
+
+            state_dict: (`optional`) dict:
+                an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
+                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
+                In this case though, you should check if using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and :func:`~transformers.FlaxPreTrainedModel.from_pretrained` is not a simpler option.
+
+            cache_dir: (`optional`) string:
+                Path to a directory in which a downloaded pre-trained model
+                configuration should be cached if the standard cache should not be used.
+
+            force_download: (`optional`) boolean, default False:
+                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
+
+            resume_download: (`optional`) boolean, default False:
+                Do not delete incompletely received file. Attempt to resume the download if such a file exists.
+
+            proxies: (`optional`) dict, default None:
+                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
+                The proxies are used on each request.
+
+            output_loading_info: (`optional`) boolean:
+                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
+
+            kwargs: (`optional`) Remaining dictionary of keyword arguments:
+                These arguments will be passed to the configuration and the model.
+
+        Examples::
+
+            model = FlaxAutoModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
+            model = FlaxAutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
+            assert model.config.output_attention == True
+
+        """
+        config = kwargs.pop("config", None)
+        if not isinstance(config, PretrainedConfig):
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        for config_class, model_class in MODEL_MAPPING.items():
+            if isinstance(config, config_class):
+                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        raise ValueError(
+            f"Unrecognized configuration class {config.__class__} "
+            f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
+            f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}"
+        )
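Both `from_config` and `from_pretrained` above walk `MODEL_MAPPING` in insertion order and return on the first `isinstance` match, so the order of the entries matters whenever one configuration class derives from another. The sketch below illustrates this with toy stand-in classes. All `Toy*` names and `toy_from_config` are hypothetical, and the inheritance shown (the RoBERTa config subclassing the BERT config) is an assumption that mirrors how the library defines `RobertaConfig`.

```python
from collections import OrderedDict


class ToyBertConfig:
    """Stand-in for a base configuration class."""


class ToyRobertaConfig(ToyBertConfig):
    """Stand-in for a configuration that subclasses the BERT one."""


class ToyBertModel:
    def __init__(self, config):
        self.config = config


class ToyRobertaModel:
    def __init__(self, config):
        self.config = config


# Most specific configuration first, as in MODEL_MAPPING.
TOY_MODEL_MAPPING = OrderedDict(
    [
        (ToyRobertaConfig, ToyRobertaModel),
        (ToyBertConfig, ToyBertModel),
    ]
)


def toy_from_config(config, mapping=TOY_MODEL_MAPPING):
    """Return the model for the first mapping entry the config is an instance of."""
    for config_class, model_class in mapping.items():
        if isinstance(config, config_class):
            return model_class(config)
    raise ValueError(f"Unrecognized configuration class {config.__class__}")


# Reversing the order makes the base class shadow the subclass.
reversed_mapping = OrderedDict(reversed(list(TOY_MODEL_MAPPING.items())))
print(type(toy_from_config(ToyRobertaConfig())).__name__)
print(type(toy_from_config(ToyRobertaConfig(), reversed_mapping)).__name__)
```

With the reversed mapping, a RoBERTa configuration is matched by the BERT entry first, which is the failure mode the ordering avoids.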
438
src/transformers/modeling_flax_bert.py
Normal file
@@ -0,0 +1,438 @@
+# coding=utf-8
+# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict
+
+import numpy as np
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.linen import compact
+
+from .configuration_bert import BertConfig
+from .file_utils import add_start_docstrings
+from .modeling_flax_utils import FlaxPreTrainedModel, gelu
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "BertConfig"
+_TOKENIZER_FOR_DOC = "BertTokenizer"
+
+
+BERT_START_DOCSTRING = r"""
+
+    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
+    generic methods the library implements for all its models (such as downloading or saving).
+
+    The underlying network is built from Flax Linen modules (``flax.linen.Module`` subclasses).
+    Refer to the Flax documentation for all matters related to general
+    usage and behavior.
+
+    Parameters:
+        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the configuration.
+            Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the model weights.
+"""
+
+BERT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using :class:`~transformers.BertTokenizer`.
+            See :meth:`transformers.PreTrainedTokenizer.encode` and
+            :meth:`transformers.PreTrainedTokenizer.__call__` for details.
+
+            `What are input IDs? <../glossary.html#input-ids>`__
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+            Mask to avoid performing attention on padding token indices.
+            Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+            Segment token indices to indicate first and second portions of the inputs.
+            Indices are selected in ``[0, 1]``:
+
+            - 0 corresponds to a `sentence A` token,
+            - 1 corresponds to a `sentence B` token.
+
+            `What are token type IDs? <../glossary.html#token-type-ids>`_
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+            Indices of positions of each input sequence tokens in the position embeddings.
+            Selected in the range ``[0, config.max_position_embeddings - 1]``.
+
+            `What are position IDs? <../glossary.html#position-ids>`_
+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules.
+            Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
+            vectors than the model's internal embedding lookup matrix.
+        output_attentions (:obj:`bool`, `optional`):
+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+            tensors for more detail.
+        output_hidden_states (:obj:`bool`, `optional`):
+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+            more detail.
+        return_dict (:obj:`bool`, `optional`):
+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+"""
+
+
+class FlaxBertLayerNorm(nn.Module):
+    """Layer normalization (https://arxiv.org/abs/1607.06450).
+    Operates on the last axis of the input data.
+    """
+
+    epsilon: float = 1e-6
+    dtype: jnp.dtype = jnp.float32
+    bias: bool = True
+    scale: bool = True
+    bias_init: jnp.ndarray = nn.initializers.zeros
+    scale_init: jnp.ndarray = nn.initializers.ones
+
+    @compact
+    def __call__(self, x):
+        """Applies layer normalization on the input.
+        It normalizes the activations of the layer for each given example in a
+        batch independently, rather than across a batch like Batch Normalization.
+        i.e. applies a transformation that maintains the mean activation within
+        each example close to 0 and the activation standard deviation close to 1.
+        Args:
+          x: the inputs
+          epsilon: A small float added to variance to avoid dividing by zero.
+          dtype: the dtype of the computation (default: float32).
+          bias:  If True, bias (beta) is added.
+          scale: If True, multiply by scale (gamma). When the next layer is linear
+            (also e.g. nn.relu), this can be disabled since the scaling will be done
+            by the next layer.
+          bias_init: Initializer for bias, by default, zero.
+          scale_init: Initializer for scale, by default, one.
+        Returns:
+          Normalized inputs (the same shape as inputs).
+        """
+        features = x.shape[-1]
+        mean = jnp.mean(x, axis=-1, keepdims=True)
+        mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
+        var = mean2 - jax.lax.square(mean)
+        mul = jax.lax.rsqrt(var + self.epsilon)
+        if self.scale:
+            mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
+        y = (x - mean) * mul
+        if self.bias:
+            y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
+        return y
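`FlaxBertLayerNorm.__call__` computes the variance as the mean of squares minus the squared mean, then multiplies by the reciprocal square root. The same arithmetic can be checked on a single vector with plain Python. This is only an illustrative sketch: the function name `layer_norm` and the list-based `gamma`/`beta` arguments are assumptions made for the example, not part of the module above.

```python
import math


def layer_norm(x, gamma=None, beta=None, epsilon=1e-6):
    """Normalize one vector (a list of floats) over its only axis."""
    n = len(x)
    mean = sum(x) / n
    mean2 = sum(v * v for v in x) / n   # E[x^2]
    var = mean2 - mean * mean           # Var[x] = E[x^2] - E[x]^2
    mul = 1.0 / math.sqrt(var + epsilon)
    gamma = [1.0] * n if gamma is None else gamma  # scale, defaults to ones
    beta = [0.0] * n if beta is None else beta     # bias, defaults to zeros
    return [(v - mean) * mul * g + b for v, g, b in zip(x, gamma, beta)]


y = layer_norm([1.0, 2.0, 3.0, 4.0])
print(sum(y) / len(y))                  # close to 0
print(sum(v * v for v in y) / len(y))   # close to 1
```

The `epsilon` term keeps the division finite when every element of the input is identical and the variance is zero.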
+
+
+class FlaxBertEmbedding(nn.Module):
+    """
+    Specify a dedicated embedding class, because Flax's own embedding module
+    names its parameter 'embedding'
+    whereas PyTorch uses 'weight'.
+    """
+
+    vocab_size: int
+    hidden_size: int
+    emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
+
+    @compact
+    def __call__(self, inputs):
+        embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
+        return jnp.take(embedding, inputs, axis=0)
+
+
+class FlaxBertEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings."""
+
+    vocab_size: int
+    hidden_size: int
+    type_vocab_size: int
+    max_length: int
+
+    @compact
+    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
+
+        # Embed
+        w_emb = FlaxBertEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
+            jnp.atleast_2d(input_ids.astype("i4"))
+        )
+        p_emb = FlaxBertEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
+            jnp.atleast_2d(position_ids.astype("i4"))
+        )
+        t_emb = FlaxBertEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
+            jnp.atleast_2d(token_type_ids.astype("i4"))
+        )
+
+        # Sum all embeddings
+        summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
+
+        # Layer Norm
+        layer_norm = FlaxBertLayerNorm(name="layer_norm")(summed_emb)
+
+        return layer_norm
|
||||
|
||||
|
||||
class FlaxBertAttention(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
|
||||
layer_norm = FlaxBertLayerNorm(name="layer_norm")(self_att + hidden_state)
|
||||
return layer_norm
|
||||
|
||||
|
||||
class FlaxBertIntermediate(nn.Module):
|
||||
output_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, hidden_state):
|
||||
# TODO: Add ACT2FN reference to change activation function
|
||||
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
|
||||
return gelu(dense)
|
||||
|
||||
|
||||
class FlaxBertOutput(nn.Module):
|
||||
@compact
|
||||
def __call__(self, intermediate_output, attention_output):
|
||||
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
|
||||
hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class FlaxBertLayer(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask)
|
||||
intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention)
|
||||
output = FlaxBertOutput(name="output")(intermediate, attention)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FlaxBertLayerCollection(nn.Module):
|
||||
"""
|
||||
Stores N BertLayer(s)
|
||||
"""
|
||||
|
||||
num_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, inputs, attention_mask):
|
||||
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
|
||||
|
||||
# Initialize input / output
|
||||
input_i = inputs
|
||||
|
||||
# Forward over all encoders
|
||||
for i in range(self.num_layers):
|
||||
layer = FlaxBertLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
|
||||
input_i = layer(input_i, attention_mask)
|
||||
return input_i
|
||||
|
||||
|
||||
class FlaxBertEncoder(nn.Module):
|
||||
num_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
layer = FlaxBertLayerCollection(
|
||||
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
|
||||
)(hidden_state, attention_mask)
|
||||
return layer
|
||||
|
||||
|
||||
class FlaxBertPooler(nn.Module):
|
||||
@compact
|
||||
def __call__(self, hidden_state):
|
||||
cls_token = hidden_state[:, 0]
|
||||
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
|
||||
return jax.lax.tanh(out)
|
||||
|
||||
|
||||
class FlaxBertModule(nn.Module):
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
num_encoder_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||
|
||||
# Embedding
|
||||
embeddings = FlaxBertEmbeddings(
|
||||
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask)
|
||||
|
||||
# N stacked encoding layers
|
||||
encoder = FlaxBertEncoder(
|
||||
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
|
||||
)(embeddings, attention_mask)
|
||||
|
||||
pooled = FlaxBertPooler(name="pooler")(encoder)
|
||||
return encoder, pooled
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxBertModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
The model can behave as an encoder (with only self-attention) as well
|
||||
as a decoder, in which case a layer of cross-attention is added between
|
||||
the self-attention layers, following the architecture described in `Attention is all you need
|
||||
<https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
|
||||
Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
"""
|
||||
|
||||
model_class = FlaxBertModule
|
||||
config_class = BertConfig
|
||||
base_model_prefix = "bert"
|
||||
|
||||
@staticmethod
|
||||
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
|
||||
jax_state = dict(pt_state)
|
||||
|
||||
# Some parameter names must change to match Flax's naming, so that we don't have to fork any layer
|
||||
for key, tensor in pt_state.items():
|
||||
# Key parts
|
||||
key_parts = set(key.split("."))
|
||||
|
||||
# Every dense layer has "kernel" parameters instead of "weight"
|
||||
if "dense.weight" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention also needs "weight" replaced by "kernel"
|
||||
if {"query", "key", "value"} & key_parts:
|
||||
|
||||
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
||||
if "bias" in key:
|
||||
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
||||
elif "weight" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove one nesting
|
||||
if "attention.output.dense" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.dense", "attention.self.out")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove nesting on layer norm
|
||||
if "attention.output.LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# There are some transposed parameters w.r.t their PyTorch counterpart
|
||||
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Self Attention output projection needs to be transposed
|
||||
if "out.kernel" in key:
|
||||
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
|
||||
1, 2, 0
|
||||
)
|
||||
|
||||
# Pooler needs to transpose its kernel
|
||||
if "pooler.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Handle LayerNorm conversion
|
||||
if "LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
|
||||
# Replace LayerNorm by layer_norm
|
||||
new_key = key.replace("LayerNorm", "layer_norm")
|
||||
|
||||
if "weight" in key:
|
||||
new_key = new_key.replace("weight", "gamma")
|
||||
elif "bias" in key:
|
||||
new_key = new_key.replace("bias", "beta")
|
||||
|
||||
jax_state[new_key] = tensor
|
||||
|
||||
return jax_state
|
||||
|
||||
def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs):
|
||||
model = FlaxBertModule(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
max_length=config.max_position_embeddings,
|
||||
num_encoder_layers=config.num_hidden_layers,
|
||||
num_heads=config.num_attention_heads,
|
||||
head_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
)
|
||||
|
||||
super().__init__(config, model, state, seed)
|
||||
|
||||
@property
|
||||
def module(self) -> nn.Module:
|
||||
return self._module
|
||||
|
||||
def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
return self.model.apply(
|
||||
{"params": self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
)
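The query/key/value reshape in `convert_from_pytorch` above can be checked in isolation. The sketch below assumes a PyTorch-style linear weight of shape `(out_features, in_features)` and a Flax-style per-head kernel of shape `(in_features, num_heads, head_dim)`. The helper name `convert_qkv_weight` and the toy sizes are hypothetical and only serve the illustration.

```python
import numpy as np


def convert_qkv_weight(weight, num_heads):
    """Reshape a (hidden, hidden) PyTorch weight into (hidden, heads, head_dim)."""
    hidden_size = weight.shape[1]
    # Split the output dimension into (heads, head_dim), then move the input dimension first.
    return weight.reshape((num_heads, -1, hidden_size)).transpose((2, 0, 1))


rng = np.random.default_rng(0)
hidden_size, num_heads = 8, 2
weight = rng.normal(size=(hidden_size, hidden_size))
x = rng.normal(size=(hidden_size,))

# PyTorch-style projection (x @ W.T), then split the result into heads.
pt_out = (weight @ x).reshape(num_heads, -1)

# Per-head projection with the converted kernel.
kernel = convert_qkv_weight(weight, num_heads)
flax_out = np.einsum("i,ihd->hd", x, kernel)
```

Both paths should give the same per-head projections, which is the property the conversion has to preserve.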
|
||||
450
src/transformers/modeling_flax_roberta.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.linen import compact
|
||||
|
||||
from .configuration_roberta import RobertaConfig
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "RobertaConfig"
|
||||
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
||||
|
||||
|
||||
ROBERTA_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
|
||||
pruning heads etc.)
|
||||
|
||||
This model is not a PyTorch ``torch.nn.Module``: it wraps a ``flax.linen.Module``, exposed through the ``module``
property. Refer to the Flax documentation for all matters related to general usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
|
||||
model. Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||
"""
|
||||
|
||||
ROBERTA_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.RobertaTokenizer`.
|
||||
See :meth:`transformers.PreTrainedTokenizer.encode` and
|
||||
:meth:`transformers.PreTrainedTokenizer.__call__` for details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``:
|
||||
|
||||
- 0 corresponds to a `sentence A` token,
|
||||
- 1 corresponds to a `sentence B` token.
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings.
|
||||
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
||||
|
||||
`What are position IDs? <../glossary.html#position-ids>`_
|
||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||
vectors than the model's internal embedding lookup matrix.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
# Copied from transformers.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta
|
||||
class FlaxRobertaLayerNorm(nn.Module):
|
||||
"""Layer normalization (https://arxiv.org/abs/1607.06450).
|
||||
Operates on the last axis of the input data.
|
||||
"""
|
||||
|
||||
epsilon: float = 1e-6
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
bias: bool = True
|
||||
scale: bool = True
|
||||
bias_init: jnp.ndarray = nn.initializers.zeros
|
||||
scale_init: jnp.ndarray = nn.initializers.ones
|
||||
|
||||
@compact
|
||||
def __call__(self, x):
|
||||
"""Applies layer normalization on the input.
|
||||
It normalizes the activations of the layer for each given example in a
|
||||
batch independently, rather than across a batch like Batch Normalization.
|
||||
i.e. applies a transformation that maintains the mean activation within
|
||||
each example close to 0 and the activation standard deviation close to 1.
|
||||
Args:
|
||||
x: the inputs
|
||||
epsilon: A small float added to variance to avoid dividing by zero.
|
||||
dtype: the dtype of the computation (default: float32).
|
||||
bias: If True, bias (beta) is added.
|
||||
scale: If True, multiply by scale (gamma). When the next layer is linear
|
||||
(also e.g. nn.relu), this can be disabled since the scaling will be done
|
||||
by the next layer.
|
||||
bias_init: Initializer for bias, by default, zero.
|
||||
scale_init: Initializer for scale, by default, one.
|
||||
Returns:
|
||||
Normalized inputs (the same shape as inputs).
|
||||
"""
|
||||
features = x.shape[-1]
|
||||
mean = jnp.mean(x, axis=-1, keepdims=True)
|
||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||
var = mean2 - jax.lax.square(mean)
|
||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||
if self.scale:
|
||||
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
|
||||
y = (x - mean) * mul
|
||||
if self.bias:
|
||||
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
|
||||
return y
|
||||
|
||||
|
||||
# Copied from transformers.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta
|
||||
class FlaxRobertaEmbedding(nn.Module):
|
||||
"""
|
||||
Dedicated embedding module that names its parameter 'weight', as PyTorch does,
because Flax's built-in embedding module uses 'embedding'.
|
||||
"""
|
||||
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
|
||||
|
||||
@compact
|
||||
def __call__(self, inputs):
|
||||
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
|
||||
return jnp.take(embedding, inputs, axis=0)
|
||||
|
||||
|
||||
# Copied from transformers.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
|
||||
class FlaxRobertaEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
|
||||
@compact
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||
|
||||
# Embed
|
||||
w_emb = FlaxRobertaEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
|
||||
jnp.atleast_2d(input_ids.astype("i4"))
|
||||
)
|
||||
p_emb = FlaxRobertaEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
|
||||
jnp.atleast_2d(position_ids.astype("i4"))
|
||||
)
|
||||
t_emb = FlaxRobertaEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
|
||||
jnp.atleast_2d(token_type_ids.astype("i4"))
|
||||
)
|
||||
|
||||
# Sum all embeddings
|
||||
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
|
||||
|
||||
# Layer Norm
|
||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(summed_emb)
|
||||
|
||||
return layer_norm
|
||||
|
||||
|
||||
# Copied from transformers.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
||||
class FlaxRobertaAttention(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
|
||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(self_att + hidden_state)
|
||||
return layer_norm
|
||||
|
||||
|
||||
# Copied from transformers.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
||||
class FlaxRobertaIntermediate(nn.Module):
|
||||
output_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, hidden_state):
|
||||
# TODO: Add ACT2FN reference to change activation function
|
||||
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
|
||||
return gelu(dense)
|
||||
|
||||
|
||||
# Copied from transformers.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
||||
class FlaxRobertaOutput(nn.Module):
|
||||
@compact
|
||||
def __call__(self, intermediate_output, attention_output):
|
||||
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
|
||||
hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class FlaxRobertaLayer(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
intermediate = FlaxRobertaIntermediate(self.intermediate_size, name="intermediate")(attention)
|
||||
output = FlaxRobertaOutput(name="output")(intermediate, attention)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# Copied from transformers.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
|
||||
class FlaxRobertaLayerCollection(nn.Module):
|
||||
"""
|
||||
Stores N RobertaLayer(s)
|
||||
"""
|
||||
|
||||
num_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, inputs, attention_mask):
|
||||
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
|
||||
|
||||
# Initialize input / output
|
||||
input_i = inputs
|
||||
|
||||
# Forward over all encoders
|
||||
for i in range(self.num_layers):
|
||||
layer = FlaxRobertaLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
|
||||
input_i = layer(input_i, attention_mask)
|
||||
return input_i
|
||||
|
||||
|
||||
# Copied from transformers.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
|
||||
class FlaxRobertaEncoder(nn.Module):
|
||||
num_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
layer = FlaxRobertaLayerCollection(
|
||||
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
|
||||
)(hidden_state, attention_mask)
|
||||
return layer
|
||||
|
||||
|
||||
# Copied from transformers.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
||||
class FlaxRobertaPooler(nn.Module):
|
||||
@compact
|
||||
def __call__(self, hidden_state):
|
||||
cls_token = hidden_state[:, 0]
|
||||
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
|
||||
return jax.lax.tanh(out)
|
||||
|
||||
|
||||
# Copied from transformers.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
||||
class FlaxRobertaModule(nn.Module):
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
num_encoder_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
|
||||
@compact
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||
|
||||
# Embedding
|
||||
embeddings = FlaxRobertaEmbeddings(
|
||||
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask)
|
||||
|
||||
# N stacked encoding layers
|
||||
encoder = FlaxRobertaEncoder(
|
||||
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
|
||||
)(embeddings, attention_mask)
|
||||
|
||||
pooled = FlaxRobertaPooler(name="pooler")(encoder)
|
||||
return encoder, pooled
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
ROBERTA_START_DOCSTRING,
|
||||
)
|
||||
class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
The model can behave as an encoder (with only self-attention) as well
|
||||
as a decoder, in which case a layer of cross-attention is added between
|
||||
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
|
||||
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
"""
|
||||
|
||||
model_class = FlaxRobertaModule
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
@staticmethod
|
||||
def convert_from_pytorch(pt_state: Dict, config: RobertaConfig) -> Dict:
|
||||
jax_state = dict(pt_state)
|
||||
|
||||
# Some parameter names must change to match Flax's naming, so that we don't have to fork any layer
|
||||
for key, tensor in pt_state.items():
|
||||
# Key parts
|
||||
key_parts = set(key.split("."))
|
||||
|
||||
# Every dense layer has "kernel" parameters instead of "weight"
|
||||
if "dense.weight" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention also needs "weight" replaced by "kernel"
|
||||
if {"query", "key", "value"} & key_parts:
|
||||
|
||||
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
||||
if "bias" in key:
|
||||
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
||||
elif "weight" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove one nesting
|
||||
if "attention.output.dense" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.dense", "attention.self.out")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove nesting on layer norm
|
||||
if "attention.output.LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# There are some transposed parameters w.r.t their PyTorch counterpart
|
||||
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Self Attention output projection needs to be transposed
|
||||
if "out.kernel" in key:
|
||||
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
|
||||
1, 2, 0
|
||||
)
|
||||
|
||||
# Pooler needs to transpose its kernel
|
||||
if "pooler.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Handle LayerNorm conversion
|
||||
if "LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
|
||||
# Replace LayerNorm by layer_norm
|
||||
new_key = key.replace("LayerNorm", "layer_norm")
|
||||
|
||||
if "weight" in key:
|
||||
new_key = new_key.replace("weight", "gamma")
|
||||
elif "bias" in key:
|
||||
new_key = new_key.replace("bias", "beta")
|
||||
|
||||
jax_state[new_key] = tensor
|
||||
|
||||
return jax_state
|
||||
|
||||
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs):
|
||||
model = FlaxRobertaModule(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
max_length=config.max_position_embeddings,
|
||||
num_encoder_layers=config.num_hidden_layers,
|
||||
num_heads=config.num_attention_heads,
|
||||
head_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
)
|
||||
|
||||
super().__init__(config, model, state, seed)
|
||||
|
||||
@property
|
||||
def module(self) -> nn.Module:
|
||||
return self._module
|
||||
|
||||
def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
|
||||
if token_type_ids is None:
|
||||
token_type_ids = jnp.zeros_like(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = np.arange(
|
||||
self.config.pad_token_id + 1, np.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
return self.model.apply(
|
||||
{"params": self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
)
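The default `position_ids` computed in `__call__` above start at `pad_token_id + 1` and ignore any padding inside the batch. The sketch below reproduces that default and, for comparison, a padding-aware variant similar in spirit to what the PyTorch RoBERTa implementation does. Both helper names and the sample token ids are assumptions made for illustration.

```python
import numpy as np


def default_position_ids(seq_len, pad_token_id):
    """Positions as computed by the default above: an offset range, no padding handling."""
    return np.arange(pad_token_id + 1, seq_len + pad_token_id + 1)


def padding_aware_position_ids(input_ids, pad_token_id):
    """Hypothetical variant: padded slots keep pad_token_id, real tokens count up from pad_token_id + 1."""
    mask = (input_ids != pad_token_id).astype("i4")
    return np.cumsum(mask, axis=-1) * mask + pad_token_id


pad_token_id = 1
input_ids = np.array([[0, 31414, 2, 1, 1]])  # last two tokens are padding

simple = default_position_ids(input_ids.shape[-1], pad_token_id)
aware = padding_aware_position_ids(input_ids, pad_token_id)
```

With right-padded inputs the two agree on the real tokens and differ only on the padded slots, so the default is adequate as long as an attention mask hides the padding.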
|
||||
194
src/transformers/modeling_flax_utils.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pickle import UnpicklingError
|
||||
from typing import Dict
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.serialization import to_bytes
|
||||
from flax.traverse_util import unflatten_dict
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@jax.jit
|
||||
def gelu(x):
|
||||
r"""Gaussian error linear unit activation function.
|
||||
|
||||
Computes the element-wise function:
|
||||
|
||||
.. math::
|
||||
\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)

This implementation uses the exact erf-based formulation, not the tanh
approximation. For more information, see `Gaussian Error Linear Units (GELUs)
<https://arxiv.org/abs/1606.08415>`_, section 2.
|
||||
"""
|
||||
return x * 0.5 * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0)))
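The return statement above evaluates the exact, erf-based GELU. A tanh-based approximation is a common alternative, and the sketch below compares the two on scalar inputs using only the standard library. The function names `gelu_exact` and `gelu_tanh` are illustrative and do not exist in this PR.

```python
import math


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), written with the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Largest absolute gap between the two forms on a grid over [-5, 5].
xs = [i / 10.0 for i in range(-50, 51)]
max_gap = max(abs(gelu_exact(v) - gelu_tanh(v)) for v in xs)
```

The gap is small on this grid, but it is not zero, so the choice matters when outputs are compared against a reference implementation at tight tolerances.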
|
||||
|
||||
|
||||
ACT2FN = {
|
||||
"gelu": nn.gelu,
|
||||
"relu": nn.relu,
|
||||
"swish": nn.swish,
|
||||
"gelu_new": gelu,
|
||||
}
|
||||
|
||||
|
||||
class FlaxPreTrainedModel(ABC):
|
||||
config_class = None
|
||||
pretrained_model_archive_map = {}
|
||||
base_model_prefix = ""
|
||||
model_class = None
|
||||
|
||||
def __init__(self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0):
|
||||
if config is None:
|
||||
raise ValueError("config cannot be None")
|
||||
|
||||
if module is None:
|
||||
raise ValueError("module cannot be None")
|
||||
|
||||
if params is None:
|
||||
raise ValueError("params cannot be None")
|
||||
|
||||
# These are private so that derived classes can expose them as typed properties.
|
||||
self._config = config
|
||||
self._module = module
|
||||
|
||||
# These are public because their types are the same for every derived class.
|
||||
self.key = PRNGKey(seed)
|
||||
self.params = params
|
||||
self.model = module
|
||||
|
||||
@property
|
||||
def config(self) -> PretrainedConfig:
|
||||
return self._config
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained Flax model from a pre-trained model configuration.
|
||||
"""
|
||||
config = kwargs.pop("config", None)
|
||||
# state_dict = kwargs.pop("state_dict", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
# from_tf = kwargs.pop("from_tf", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
# output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_cdn = kwargs.pop("use_cdn", True)
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config_path = config if config is not None else pretrained_model_name_or_path
|
||||
config, model_kwargs = cls.config_class.from_pretrained(
|
||||
config_path,
|
||||
*model_args,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
model_kwargs = kwargs
|
||||
|
||||
# Load model
|
||||
if pretrained_model_name_or_path is not None:
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
else:
|
||||
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, use_cdn=use_cdn)
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except EnvironmentError:
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
|
||||
else:
|
||||
msg = (
|
||||
f"Model name '{pretrained_model_name_or_path}' "
|
||||
f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
|
||||
f"We assumed '{archive_file}' was a path or url to model weight files but "
|
||||
"couldn't find any such file at this path or url."
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
else:
|
||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
||||
else:
|
||||
resolved_archive_file = None
|
||||
|
||||
# Instantiate model.
|
||||
with open(resolved_archive_file, "rb") as state_f:
|
||||
try:
|
||||
from flax.serialization import from_bytes
|
||||
|
||||
state = from_bytes(cls.model_class, state_f)
|
||||
except TypeError:
|
||||
try:
|
||||
import torch
|
||||
|
||||
state = torch.load(state_f)
|
||||
state = {k: v.numpy() for k, v in state.items()}
|
||||
state = cls.convert_from_pytorch(state, config)
|
||||
state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
|
||||
except UnpicklingError:
|
||||
raise EnvironmentError(
|
||||
f"Unable to convert model {archive_file} to Flax deserializable object. "
|
||||
"Supported formats are PyTorch archive or Flax msgpack"
|
||||
)
|
||||
|
||||
return cls(config, state, *model_args, **model_kwargs)
|
||||
|
||||
def save_pretrained(self, folder):
|
||||
folder_abs = os.path.abspath(folder)
|
||||
|
||||
if not os.path.exists(folder_abs):
|
||||
os.mkdir(folder_abs)
|
||||
|
||||
with open(os.path.join(folder_abs, f"{self._config.model_type}.flax"), "wb") as f:
|
||||
model_bytes = to_bytes(self.params)
|
||||
f.write(model_bytes)
|
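The PyTorch fallback in `from_pretrained` splits each checkpoint key on `.`, drops the first component, and passes the resulting tuple keys to `unflatten_dict`. A minimal pure-Python sketch of that step follows, assuming tuple keys map onto nested dictionaries; `unflatten`, `strip_prefix_and_unflatten`, and the sample keys are hypothetical stand-ins for illustration, not the library's code.

```python
from typing import Any, Dict, Tuple


def unflatten(flat: Dict[Tuple[str, ...], Any]) -> Dict[str, Any]:
    """Turn {("a", "b"): v} into {"a": {"b": v}} (hypothetical helper)."""
    nested: Dict[str, Any] = {}
    for path, value in flat.items():
        node = nested
        # Walk (and create) intermediate dictionaries for all but the last part.
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return nested


def strip_prefix_and_unflatten(state: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors tuple(k.split(".")[1:]): drop the leading component, keep the rest.
    return unflatten({tuple(k.split(".")[1:]): v for k, v in state.items()})


# Illustrative keys only; a real checkpoint has many more entries.
sample = {
    "bert.embeddings.word_embeddings": 1,
    "bert.pooler.dense.kernel": 2,
    "bert.pooler.dense.bias": 3,
}
tree = strip_prefix_and_unflatten(sample)
```

Dropping the first component assumes every key carries a model-name prefix such as `bert.`; keys without one would lose a meaningful level.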
||||
@@ -13,6 +13,7 @@ from pathlib import Path
|
||||
from .file_utils import (
|
||||
_datasets_available,
|
||||
_faiss_available,
|
||||
_flax_available,
|
||||
_sentencepiece_available,
|
||||
_tf_available,
|
||||
_tokenizers_available,
|
||||
@@ -115,6 +116,18 @@ def require_tf(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_flax(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires JAX & Flax
|
||||
|
||||
These tests are skipped when either or both are not installed.
|
||||
|
||||
"""
|
||||
if not _flax_available:
|
||||
test_case = unittest.skip("test requires JAX & Flax")(test_case)
|
||||
return test_case
|
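The `require_flax` decorator follows the usual pattern of wrapping a test in `unittest.skip` when a dependency is missing. Here is a minimal standalone sketch of that pattern, assuming availability is detected with `importlib.util.find_spec`; `require_modules` and the two example classes are hypothetical and not part of `testing_utils`.

```python
import importlib.util
import unittest


def _is_installed(module_name: str) -> bool:
    # find_spec returns None when a top-level module cannot be located.
    return importlib.util.find_spec(module_name) is not None


def require_modules(*module_names: str):
    """Hypothetical generalisation of require_flax: skip unless all modules exist."""
    missing = [name for name in module_names if not _is_installed(name)]

    def decorator(test_case):
        if missing:
            reason = "test requires " + ", ".join(missing)
            return unittest.skip(reason)(test_case)
        return test_case

    return decorator


@require_modules("json")
class AlwaysRuns(unittest.TestCase):
    def test_ok(self):
        self.assertTrue(True)


@require_modules("surely_not_an_installed_module_xyz")
class AlwaysSkipped(unittest.TestCase):
    def test_never_runs(self):
        self.fail("should have been skipped")
```

Because the check runs at decoration time, a skipped class is still collected by the test runner and reported as skipped rather than silently dropped.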
||||
|
||||
|
||||
def require_sentencepiece(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires SentencePiece.
|
||||
|
||||
@@ -33,6 +33,7 @@ from .file_utils import (
|
||||
add_end_docstrings,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_flax_available,
|
||||
is_remote_url,
|
||||
is_tf_available,
|
||||
is_tokenizers_available,
|
||||
@@ -47,6 +48,8 @@ if is_tf_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
if is_tokenizers_available():
|
||||
from tokenizers import AddedToken
|
||||
@@ -143,6 +146,7 @@ class TensorType(ExplicitEnum):
|
||||
PYTORCH = "pt"
|
||||
TENSORFLOW = "tf"
|
||||
NUMPY = "np"
|
||||
JAX = "jax"
|
||||
|
||||
|
||||
class CharSpan(NamedTuple):
|
||||
@@ -559,18 +563,27 @@ class BatchEncoding(UserDict):
|
||||
tensor_type = TensorType(tensor_type)
|
||||
|
||||
# Get a function reference for the correct framework
|
||||
if tensor_type == TensorType.TENSORFLOW and is_tf_available():
|
||||
as_tensor = tf.constant
|
||||
elif tensor_type == TensorType.PYTORCH and is_torch_available():
|
||||
as_tensor = torch.tensor
|
||||
elif tensor_type == TensorType.NUMPY:
|
||||
as_tensor = np.asarray
|
||||
else:
|
||||
if tensor_type == TensorType.TENSORFLOW:
|
||||
if not is_tf_available():
|
||||
raise ImportError(
|
||||
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
|
||||
tensor_type
|
||||
)
|
||||
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
|
||||
)
|
||||
as_tensor = tf.constant
|
||||
elif tensor_type == TensorType.PYTORCH:
|
||||
if not is_torch_available():
|
||||
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||
as_tensor = torch.tensor
|
||||
elif tensor_type == TensorType.JAX:
|
||||
if not is_flax_available():
|
||||
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
|
||||
as_tensor = jnp.array
|
||||
else:
|
||||
as_tensor = np.asarray
|
||||
# (mfuntowicz: This code is unreachable)
|
||||
# else:
|
||||
# raise ImportError(
|
||||
# "Unable to convert output to tensors format {}".format(tensor_type)
|
||||
# )
|
||||
|
||||
# Do the tensor conversion in batch
|
||||
for key, value in self.items():
|
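The framework dispatch in `BatchEncoding` picks a converter per tensor type and raises `ImportError` when the backend is missing. The same idea can be sketched as a table-driven lookup instead of an `if`/`elif` chain. This is an illustration under stated assumptions, not the code in the diff: `TensorKind`, `_BACKENDS`, and `resolve_converter` are hypothetical names, and availability is checked with `importlib.util.find_spec`.

```python
import importlib
import importlib.util
from enum import Enum
from typing import Callable


class TensorKind(Enum):
    # Hypothetical stand-in for TensorType; the values follow the diff.
    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"


# (module to import, attribute used as converter, human-readable name)
_BACKENDS = {
    TensorKind.PYTORCH: ("torch", "tensor", "PyTorch"),
    TensorKind.TENSORFLOW: ("tensorflow", "constant", "TensorFlow"),
    TensorKind.NUMPY: ("numpy", "asarray", "NumPy"),
    TensorKind.JAX: ("jax.numpy", "array", "JAX"),
}


def resolve_converter(kind, backends=_BACKENDS) -> Callable:
    """Return the converter for `kind`, importing its backend lazily."""
    module_name, attr, label = backends[TensorKind(kind)]
    top_level = module_name.split(".")[0]
    if importlib.util.find_spec(top_level) is None:
        raise ImportError(
            f"Unable to convert output to {label} tensors format, {label} is not installed."
        )
    return getattr(importlib.import_module(module_name), attr)
```

With this layout, supporting another framework means adding one table row; the trade-off is that the import happens on every call unless the result is cached.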
||||
|
||||
64
tests/test_flax_auto.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import unittest
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer, BertConfig, TensorType, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
from transformers.modeling_flax_auto import FlaxAutoModel
|
||||
from transformers.modeling_flax_bert import FlaxBertModel
|
||||
from transformers.modeling_flax_roberta import FlaxRobertaModel
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxAutoModelTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_bert_from_pretrained(self):
|
||||
for model_name in ["bert-base-cased", "bert-large-uncased"]:
|
||||
with self.subTest(model_name):
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = FlaxAutoModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, FlaxBertModel)
|
||||
|
||||
@slow
|
||||
def test_roberta_from_pretrained(self):
|
||||
for model_name in ["roberta-base", "roberta-large"]:
|
||||
with self.subTest(model_name):
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = FlaxAutoModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, FlaxRobertaModel)
|
||||
|
||||
@slow
|
||||
def test_bert_jax_jit(self):
|
||||
for model_name in ["bert-base-cased", "bert-large-uncased"]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = FlaxBertModel.from_pretrained(model_name)
|
||||
tokens = tokenizer("Do you support jax jitted function?", return_tensors=TensorType.JAX)
|
||||
|
||||
@jax.jit
|
||||
def eval(**kwargs):
|
||||
return model(**kwargs)
|
||||
|
||||
eval(**tokens).block_until_ready()
|
||||
|
||||
@slow
|
||||
def test_roberta_jax_jit(self):
|
||||
for model_name in ["roberta-base", "roberta-large"]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = FlaxRobertaModel.from_pretrained(model_name)
|
||||
tokens = tokenizer("Do you support jax jitted function?", return_tensors=TensorType.JAX)
|
||||
|
||||
@jax.jit
|
||||
def eval(**kwargs):
|
||||
return model(**kwargs)
|
||||
|
||||
eval(**tokens).block_until_ready()
|
||||
42
tests/test_modeling_flax_bert.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import unittest
|
||||
|
||||
from numpy import ndarray
|
||||
|
||||
from transformers import TensorType, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import require_flax, require_torch
|
||||
from transformers.tokenization_bert_fast import BertTokenizerFast
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.modeling_flax_bert import FlaxBertModel
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.modeling_bert import BertModel
|
||||
|
||||
|
||||
@require_flax
|
||||
@require_torch
|
||||
class FlaxBertModelTest(unittest.TestCase):
|
||||
def test_from_pytorch(self):
|
||||
with torch.no_grad():
|
||||
with self.subTest("bert-base-cased"):
|
||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||
fx_model = FlaxBertModel.from_pretrained("bert-base-cased")
|
||||
pt_model = BertModel.from_pretrained("bert-base-cased")
|
||||
|
||||
# Check for simple input
|
||||
pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
|
||||
fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-4)
|
||||
|
||||
def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
|
||||
diff = abs(a - b).max()
|
||||
self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (> {})".format(diff, tol))
|
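When comparing outputs across frameworks, a signed sum of differences can report zero even when individual elements disagree, because errors of opposite sign cancel. The sketch below uses plain Python lists and hypothetical helper names to contrast that with a maximum absolute difference.

```python
from typing import Sequence


def signed_sum_diff(a: Sequence[float], b: Sequence[float]) -> float:
    # Positive and negative errors cancel each other out.
    return sum(x - y for x, y in zip(a, b))


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    # Largest element-wise deviation; no cancellation is possible.
    return max(abs(x - y) for x, y in zip(a, b))


reference = [0.0, 0.0]
candidate = [1.0, -1.0]

signed = signed_sum_diff(candidate, reference)
worst = max_abs_diff(candidate, reference)
```

Here `signed` comes out as zero while `worst` is one, so only the second measure would make a tolerance check fail for these inputs.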
||||
42
tests/test_modeling_flax_roberta.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import unittest
|
||||
|
||||
from numpy import ndarray
|
||||
|
||||
from transformers import TensorType, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import require_flax, require_torch
|
||||
from transformers.tokenization_roberta_fast import RobertaTokenizerFast
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.modeling_flax_roberta import FlaxRobertaModel
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.modeling_roberta import RobertaModel
|
||||
|
||||
|
||||
@require_flax
|
||||
@require_torch
|
||||
class FlaxRobertaModelTest(unittest.TestCase):
|
||||
def test_from_pytorch(self):
|
||||
with torch.no_grad():
|
||||
with self.subTest("roberta-base"):
|
||||
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
|
||||
fx_model = FlaxRobertaModel.from_pretrained("roberta-base")
|
||||
pt_model = RobertaModel.from_pretrained("roberta-base")
|
||||
|
||||
# Check for simple input
|
||||
pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
|
||||
fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**fx_inputs)
|
||||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-4)
|
||||
|
||||
def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
|
||||
diff = abs(a - b).max()
|
||||
self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (> {})".format(diff, tol))
|
||||