Integrate Bert-like model on Flax runtime. (#3722)
* WIP flax bert
* Initial commit Bert Jax/Flax implementation.
* Embeddings working and equivalent to PyTorch.
* Move embeddings in its own module BertEmbeddings
* Added jax.jit annotation on forward call
* BertEncoder on par with PyTorch ! :D
* Add BertPooler on par with PyTorch !!
* Working Jax+Flax implementation of BertModel with < 1e-5 differences on the last layer.
* Fix pooled output to take only the first token of the sequence.
* Refactoring to use BertConfig from transformers.
* Renamed FXBertModel to FlaxBertModel
* Model is now initialized in FlaxBertModel constructor and reused.
* WIP JaxPreTrainedModel
* Cleaning up the code of FlaxBertModel
* Added ability to load Flax model saved through save_pretrained()
* Added ability to convert Pytorch Bert model to FlaxBert
* FlaxBert can now load every Pytorch Bert model with on-the-fly conversion
* Fix hardcoded shape values in conversion scripts.
* Improve the way we handle LayerNorm conversion from PyTorch to Flax.
* Added positional embeddings as parameter of BertModel with default to np.arange.
* Let's roll FlaxRoberta !
* Fix missing position_ids parameters on predict for Bert
* Flax backend now supports batched inputs
  Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Make it possible to load msgpacked model on convert from pytorch in last resort.
  Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Moved save_pretrained to Jax base class along with more constructor parameters.
* Use specialized, model dependent conversion function.
* Expose `is_flax_available` in file_utils.
* Added unittest for Flax models.
* Added run_tests_flax to the CI.
* Introduce FlaxAutoModel
* Added more unittests
* Flax model reference the _MODEL_ARCHIVE_MAP from PyTorch model.
* Addressing review comments.
* Expose seed in both Bert and Roberta
* Fix typo suggested by @stefan-it
  Co-Authored-By: Stefan Schweter <stefan@schweter.it>
* Attempt to make style
* Attempt to make style in tests too
* Added jax & jaxlib to the flax optional dependencies.
* Attempt to fix flake8 warnings ...
* Redo black again and again
* When black and flake8 fight each other for a space ... 💥 💥 💥
* Try removing trailing comma to make both black and flake happy!
* Fix invalid is_<framework>_available call, thanks @LysandreJik 🎉
* Fix another invalid import in flax_roberta test
* Bump and pin flax release to 0.1.0.
* Make flake8 happy, remove unused jax import
* Change the type of the catch for msgpack.
* Remove unused import.
* Put seed as optional constructor parameter.
* trigger ci again
* Fix too many parameters in BertAttention.
* Formatting.
* Simplify Flax unittests to avoid machine crashes.
* Fix invalid number of arguments when raising issue for an unknown model.
* Address @bastings comment in PR, moving jax.jit decorated outside of __call__
* Fix incorrect path to require_flax/require_pytorch functions.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Attempt to make style.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Correct rebasing of circle-ci dependencies
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Fix import sorting.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Fix unused imports.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Again import sorting...
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Installing missing nlp dependency for flax unittests.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Fix loading of model for Flax implementations.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* jit the inner function call to make JAX-compatible
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Format !
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Flake one more time 🎶
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Rewrites BERT in Flax to the new Linen API (#7211)
* Rewrite Flax HuggingFace PR to Linen
* Some fixes
* Fix tests
* Fix CI with change of name of nlp (#7054)
* nlp -> datasets
* More nlp -> datasets
* Woopsie
* More nlp -> datasets
* One last
* Expose `is_flax_available` in file_utils.
* Added run_tests_flax to the CI.
* Attempt to make style
* trigger ci again
* Fix import sorting.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Revert "Rewrites BERT in Flax to the new Linen API (#7211)"
  This reverts commit 23703a5eb3364e26a1cbc3ee34b4710d86a674b0.
* Remove jnp.lax references
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Make style.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Reintroduce Linen changes ...
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Make style.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Use jax native's gelu function.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Renaming BertModel to BertModule to highlight the fact this is the Flax Module object.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Rewrite FlaxAutoModel test to not rely on pretrained_model_archive_map
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Remove unused variable in BertModule.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Remove unused variable in BertModule again
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Attempt to have is_flax_available working again.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Introduce JAX TensorType
  Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Improve ImportError message when trying to convert to various TensorType format.
  Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Makes Flax model jittable.
  Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Ensure flax models are jittable in unittests.
  Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
* Remove unused imports.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Ensure jax imports are guarded behind is_flax_available.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Make style.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Make style again
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Make style again again
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Make style again again again
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Update src/transformers/file_utils.py
  Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
* Bump flax to its latest version
  Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
* Bump jax version to at least 0.2.0
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Style.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Update the unittest to use TensorType.JAX
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* isort import in tests.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Match new flax parameters name "params"
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Remove unused imports.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Add flax models to transformers __init__
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Attempt to address all CI related comments.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Correct circle.yml indent.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Correct circle.yml indent (2)
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Remove coverage from flax tests
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Addressing many naming suggestions from comments
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Simplify for loop logic to iterate over layers in FlaxBertLayerCollection
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* use f-string syntax for formatting logs.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Use config property from FlaxPreTrainedModel.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* use "cls_token" instead of "first_token" variable name.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* use "hidden_state" instead of "h" variable name.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Correct class reference in docstring to link to Flax related modules.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Added HF + Google Flax team copyright.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Make Roberta independent from Bert
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Move activation functions to flax_utils.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Move activation functions to flax_utils for bert.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Added docstring for BERT
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Update import for Bert and Roberta tokenizers
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Make style.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* fix-copies
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Correct FlaxRobertaLayer to match PyTorch.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Use the same store_artifact for flax unittest
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Style.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Make sure gradients are disabled only locally for flax unittest using torch equivalence.
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
* Use relative imports
  Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

Co-authored-by: Stefan Schweter <stefan@schweter.it>
Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@@ -139,6 +139,31 @@ jobs:
             - store_artifacts:
                   path: ~/transformers/output.txt
                   destination: test_output.txt
+    run_tests_flax:
+        working_directory: ~/transformers
+        docker:
+            - image: circleci/python:3.7
+        environment:
+            OMP_NUM_THREADS: 1
+        resource_class: xlarge
+        parallelism: 1
+        steps:
+            - checkout
+            - restore_cache:
+                  keys:
+                      - v0.3-flax-{{ checksum "setup.py" }}
+                      - v0.3-{{ checksum "setup.py" }}
+            - run: pip install --upgrade pip
+            - run: pip install git+https://github.com/huggingface/datasets
+            - run: sudo pip install .[flax,sklearn,torch,testing]
+            - save_cache:
+                  key: v0.3-flax-{{ checksum "setup.py" }}
+                  paths:
+                      - '~/.cache/pip'
+            - run: python -m pytest -n 8 --dist=loadfile -rA -s ./tests/ | tee output.txt
+            - store_artifacts:
+                  path: ~/transformers/output.txt
+                  destination: test_output.txt
     run_tests_custom_tokenizers:
         working_directory: ~/transformers
         docker:
@@ -305,6 +330,7 @@ workflows:
             - run_tests_torch_and_tf
             - run_tests_torch
             - run_tests_tf
+            - run_tests_flax
             - build_doc
             - deploy_doc: *workflow_filters
     tpu_testing_jobs:
1
setup.py
@@ -87,6 +87,7 @@ extras["tf-cpu"] = [
     # "keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx",
 ]
 extras["torch"] = ["torch>=1.0"]
+extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"]
 extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"]
 
 extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
@@ -103,6 +103,7 @@ from .file_utils import (
     is_apex_available,
     is_datasets_available,
     is_faiss_available,
+    is_flax_available,
     is_psutil_available,
     is_py3nvml_available,
     is_sentencepiece_available,
@@ -817,6 +818,10 @@ else:
     from .utils.dummy_tf_objects import *
 
 
+if is_flax_available():
+    from .modeling_flax_bert import FlaxBertModel
+    from .modeling_flax_roberta import FlaxRobertaModel
+
 if not is_tf_available() and not is_torch_available():
     logger.warning(
         "Neither PyTorch nor TensorFlow >= 2.0 have been found."
@@ -34,10 +34,13 @@ from .utils import logging
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
+ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
+
 try:
     USE_TF = os.environ.get("USE_TF", "AUTO").upper()
     USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-    if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
+    if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
         import torch
 
         _torch_available = True # pylint: disable=invalid-name
@@ -52,7 +55,7 @@ try:
     USE_TF = os.environ.get("USE_TF", "AUTO").upper()
     USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
 
-    if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
+    if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
         import tensorflow as tf
 
         assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
@@ -65,6 +68,22 @@ except (ImportError, AssertionError):
     _tf_available = False # pylint: disable=invalid-name
 
 
+try:
+    USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+
+    if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+        import flax
+        import jax
+
+        logger.info("JAX version {}, Flax: available".format(jax.__version__))
+        logger.info("Flax available: {}".format(flax))
+        _flax_available = True
+    else:
+        _flax_available = False
+except ImportError:
+    _flax_available = False # pylint: disable=invalid-name
+
+
 try:
     import datasets # noqa: F401
 
@@ -213,6 +232,10 @@ def is_tf_available():
     return _tf_available
 
 
+def is_flax_available():
+    return _flax_available
+
+
 def is_torch_tpu_available():
     return _torch_tpu_available
 
|
|||||||
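The availability check added to file_utils.py above combines an environment switch (`USE_FLAX`, defaulting to `AUTO`) with an import attempt. The following is a minimal standalone sketch of that idea, not the library's code: `backend_enabled` is a hypothetical helper, it probes with `importlib.util.find_spec` instead of the `try`/`import` used in the diff, and the stdlib module `json` stands in for jax/flax so the sketch runs anywhere.

```python
import importlib.util
import os

# Same value sets as introduced in the diff.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})


def backend_enabled(env_var, module_names, environ=None):
    """Hypothetical helper: the env var must allow the backend and every module must be importable."""
    environ = os.environ if environ is None else environ
    flag = environ.get(env_var, "AUTO").upper()
    if flag not in ENV_VARS_TRUE_AND_AUTO_VALUES:
        return False
    # For a top-level module name, find_spec returns None when it is not installed.
    return all(importlib.util.find_spec(name) is not None for name in module_names)


# "json" is a stand-in for jax/flax (assumption made so the sketch is self-contained).
print(backend_enabled("USE_FLAX", ["json"], environ={}))                  # variable unset, treated as AUTO
print(backend_enabled("USE_FLAX", ["json"], environ={"USE_FLAX": "no"}))  # explicitly switched off
```

Probing with `find_spec` avoids importing a heavy package just to test for it; the diff instead imports jax and flax directly, which also lets it log the JAX version.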
167
src/transformers/modeling_flax_auto.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" Auto Model class. """
|
||||||
|
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from .configuration_auto import AutoConfig, BertConfig, RobertaConfig
|
||||||
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .modeling_flax_bert import FlaxBertModel
|
||||||
|
from .modeling_flax_roberta import FlaxRobertaModel
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
||||||
|
(key, value)
|
||||||
|
for pretrained_map in [
|
||||||
|
FlaxBertModel.pretrained_model_archive_map,
|
||||||
|
FlaxRobertaModel.pretrained_model_archive_map,
|
||||||
|
]
|
||||||
|
for key, value, in pretrained_map.items()
|
||||||
|
)
|
||||||
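The nested generator above flattens several per-model archive maps into one dictionary. A small sketch with made-up maps (the names and URLs below are placeholders, not real checkpoints) shows the same construction and its behaviour:

```python
# Placeholder maps standing in for FlaxBertModel.pretrained_model_archive_map etc.
bert_map = {"bert-base-uncased": "https://example.invalid/bert.bin"}
roberta_map = {"roberta-base": "https://example.invalid/roberta.bin"}

# Same shape as ALL_PRETRAINED_MODEL_ARCHIVE_MAP in the diff.
merged = dict(
    (key, value)
    for pretrained_map in [bert_map, roberta_map]
    for key, value in pretrained_map.items()
)

# On a key collision the later map wins, exactly as with {**bert_map, **roberta_map}.
print(sorted(merged))
```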
|
|
||||||
|
MODEL_MAPPING = OrderedDict(
|
||||||
|
[
|
||||||
|
(RobertaConfig, FlaxRobertaModel),
|
||||||
|
(BertConfig, FlaxBertModel),
|
||||||
|
]
|
||||||
|
)
|
||||||
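The order of entries in MODEL_MAPPING matters because the lookup below it uses `isinstance` and returns the first match. Assuming, as was the case in the library around this commit to the best of my knowledge, that `RobertaConfig` subclasses `BertConfig`, the more specific class has to come first. A sketch with stand-in classes (hypothetical, defined here only for illustration):

```python
from collections import OrderedDict


# Stand-in config classes; only the subclass relationship matters.
class BertConfig:
    pass


class RobertaConfig(BertConfig):
    pass


def pick(config, mapping):
    """Return the first mapped value whose config class matches via isinstance."""
    for config_class, model_name in mapping.items():
        if isinstance(config, config_class):
            return model_name
    raise ValueError(f"Unrecognized configuration class {config.__class__}")


specific_first = OrderedDict([(RobertaConfig, "FlaxRobertaModel"), (BertConfig, "FlaxBertModel")])
general_first = OrderedDict([(BertConfig, "FlaxBertModel"), (RobertaConfig, "FlaxRobertaModel")])

print(pick(RobertaConfig(), specific_first))  # subclass entry is reached first
print(pick(RobertaConfig(), general_first))   # base-class entry shadows the subclass entry
```

With the base class listed first, a RoBERTa configuration would silently be routed to the BERT model class, which is why the mapping lists `RobertaConfig` before `BertConfig`.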
|
|
||||||
|
|
||||||
|
class FlaxAutoModel(object):
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.FlaxAutoModel` is a generic model class
|
||||||
|
that will be instantiated as one of the base model classes of the library
|
||||||
|
when created with the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)`
|
||||||
|
or the `FlaxAutoModel.from_config(config)` class methods.
|
||||||
|
|
||||||
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise EnvironmentError(
|
||||||
|
"FlaxAutoModel is designed to be instantiated "
|
||||||
|
"using the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or "
|
||||||
|
"`FlaxAutoModel.from_config(config)` methods."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config):
|
||||||
|
r"""Instantiates one of the base model classes of the library
|
||||||
|
from a configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (:class:`~transformers.PretrainedConfig`):
|
||||||
|
The model class to instantiate is selected based on the configuration class:
|
||||||
|
|
||||||
|
- isInstance of `roberta` configuration class: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
|
||||||
|
- isInstance of `bert` configuration class: :class:`~transformers.FlaxBertModel` (Bert model)
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
|
model = FlaxAutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
"""
|
||||||
|
for config_class, model_class in MODEL_MAPPING.items():
|
||||||
|
if isinstance(config, config_class):
|
||||||
|
return model_class(config)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unrecognized configuration class {config.__class__} "
|
||||||
|
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
||||||
|
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
r"""Instantiates one of the base model classes of the library
|
||||||
|
from a pre-trained model configuration.
|
||||||
|
|
||||||
|
The `from_pretrained()` method takes care of returning the correct model class instance
|
||||||
|
based on the `model_type` property of the config object, or when it's missing,
|
||||||
|
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
|
||||||
|
|
||||||
|
The base model class to instantiate is selected as the first pattern matching
|
||||||
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
|
- contains `roberta`: :class:`~transformers.FlaxRobertaModel` (RoBERTa model)
|
||||||
|
- contains `bert`: :class:`~transformers.FlaxBertModel` (Bert model)
|
||||||
|
|
||||||
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
|
To train the model, you should first set it back in training mode with `model.train()`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretrained_model_name_or_path: either:
|
||||||
|
|
||||||
|
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
||||||
|
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
||||||
|
- a path to a `directory` containing model weights saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||||
|
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||||
|
|
||||||
|
model_args: (`optional`) Sequence of positional arguments:
|
||||||
|
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
|
||||||
|
|
||||||
|
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||||
|
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
|
||||||
|
|
||||||
|
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||||
|
- the model was saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
|
||||||
|
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
||||||
|
|
||||||
|
state_dict: (`optional`) dict:
|
||||||
|
an optional state dictionary for the model to use instead of a state dictionary loaded from a saved weights file.
|
||||||
|
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
||||||
|
In this case though, you should check if using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and :func:`~transformers.FlaxPreTrainedModel.from_pretrained` is not a simpler option.
|
||||||
|
|
||||||
|
cache_dir: (`optional`) string:
|
||||||
|
Path to a directory in which a downloaded pre-trained model
|
||||||
|
configuration should be cached if the standard cache should not be used.
|
||||||
|
|
||||||
|
force_download: (`optional`) boolean, default False:
|
||||||
|
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
|
||||||
|
|
||||||
|
resume_download: (`optional`) boolean, default False:
|
||||||
|
Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
|
||||||
|
|
||||||
|
proxies: (`optional`) dict, default None:
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
||||||
|
The proxies are used on each request.
|
||||||
|
|
||||||
|
output_loading_info: (`optional`) boolean:
|
||||||
|
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||||
|
|
||||||
|
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
||||||
|
These arguments will be passed to the configuration and the model.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
model = FlaxAutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
|
||||||
|
model = FlaxAutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
|
assert model.config.output_attention == True
|
||||||
|
|
||||||
|
"""
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
|
for config_class, model_class in MODEL_MAPPING.items():
|
||||||
|
if isinstance(config, config_class):
|
||||||
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unrecognized configuration class {config.__class__} "
|
||||||
|
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
|
||||||
|
f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}"
|
||||||
|
)
|
||||||
438
src/transformers/modeling_flax_bert.py
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Callable, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from flax.linen import compact
|
||||||
|
|
||||||
|
from .configuration_bert import BertConfig
|
||||||
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CONFIG_FOR_DOC = "BertConfig"
|
||||||
|
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||||
|
|
||||||
|
|
||||||
|
BERT_START_DOCSTRING = r"""
|
||||||
|
|
||||||
|
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||||
|
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
|
||||||
|
pruning heads etc.)
|
||||||
|
|
||||||
|
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__ subclass.
|
||||||
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
|
||||||
|
usage and behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
BERT_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
||||||
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.BertTokenizer`.
|
||||||
|
See :meth:`transformers.PreTrainedTokenizer.encode` and
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.__call__` for details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices.
|
||||||
|
Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||||
|
Segment token indices to indicate first and second portions of the inputs.
|
||||||
|
Indices are selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 0 corresponds to a `sentence A` token,
|
||||||
|
- 1 corresponds to a `sentence B` token.
|
||||||
|
|
||||||
|
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||||
|
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||||
|
Indices of positions of each input sequence token in the position embeddings.
|
||||||
|
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
||||||
|
|
||||||
|
`What are position IDs? <../glossary.html#position-ids>`_
|
||||||
|
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the self-attention modules.
|
||||||
|
Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
||||||
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||||
|
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||||
|
vectors than the model's internal embedding lookup matrix.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertLayerNorm(nn.Module):
|
||||||
|
"""Layer normalization (https://arxiv.org/abs/1607.06450).
|
||||||
|
Operates on the last axis of the input data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
epsilon: float = 1e-6
|
||||||
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
bias: bool = True
|
||||||
|
scale: bool = True
|
||||||
|
bias_init: jnp.ndarray = nn.initializers.zeros
|
||||||
|
scale_init: jnp.ndarray = nn.initializers.ones
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, x):
|
||||||
|
"""Applies layer normalization on the input.
|
||||||
|
It normalizes the activations of the layer for each given example in a
|
||||||
|
batch independently, rather than across a batch like Batch Normalization.
|
||||||
|
i.e. applies a transformation that maintains the mean activation within
|
||||||
|
each example close to 0 and the activation standard deviation close to 1.
|
||||||
|
Args:
|
||||||
|
x: the inputs
|
||||||
|
epsilon: A small float added to variance to avoid dividing by zero.
|
||||||
|
dtype: the dtype of the computation (default: float32).
|
||||||
|
bias: If True, bias (beta) is added.
|
||||||
|
scale: If True, multiply by scale (gamma). When the next layer is linear
|
||||||
|
(also e.g. nn.relu), this can be disabled since the scaling will be done
|
||||||
|
by the next layer.
|
||||||
|
bias_init: Initializer for bias, by default, zero.
|
||||||
|
scale_init: Initializer for scale, by default, one.
|
||||||
|
Returns:
|
||||||
|
Normalized inputs (the same shape as inputs).
|
||||||
|
"""
|
||||||
|
features = x.shape[-1]
|
||||||
|
mean = jnp.mean(x, axis=-1, keepdims=True)
|
||||||
|
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||||
|
var = mean2 - jax.lax.square(mean)
|
||||||
|
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||||
|
if self.scale:
|
||||||
|
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
|
||||||
|
y = (x - mean) * mul
|
||||||
|
if self.bias:
|
||||||
|
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertEmbedding(nn.Module):
|
||||||
|
"""
|
||||||
|
Custom embedding layer, needed because Flax's built-in one
|
||||||
|
names its parameter 'embedding'
|
||||||
|
whereas PyTorch uses 'weight'
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_size: int
|
||||||
|
hidden_size: int
|
||||||
|
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, inputs):
|
||||||
|
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
|
||||||
|
return jnp.take(embedding, inputs, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertEmbeddings(nn.Module):
|
||||||
|
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||||
|
|
||||||
|
vocab_size: int
|
||||||
|
hidden_size: int
|
||||||
|
type_vocab_size: int
|
||||||
|
max_length: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||||
|
|
||||||
|
# Embed
|
||||||
|
w_emb = FlaxBertEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
|
||||||
|
jnp.atleast_2d(input_ids.astype("i4"))
|
||||||
|
)
|
||||||
|
p_emb = FlaxBertEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
|
||||||
|
jnp.atleast_2d(position_ids.astype("i4"))
|
||||||
|
)
|
||||||
|
t_emb = FlaxBertEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
|
||||||
|
jnp.atleast_2d(token_type_ids.astype("i4"))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sum all embeddings
|
||||||
|
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
|
||||||
|
|
||||||
|
# Layer Norm
|
||||||
|
layer_norm = FlaxBertLayerNorm(name="layer_norm")(summed_emb)
|
||||||
|
|
||||||
|
return layer_norm
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertAttention(nn.Module):
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, hidden_state, attention_mask):
|
||||||
|
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
|
||||||
|
hidden_state, attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
layer_norm = FlaxBertLayerNorm(name="layer_norm")(self_att + hidden_state)
|
||||||
|
return layer_norm
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertIntermediate(nn.Module):
|
||||||
|
output_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, hidden_state):
|
||||||
|
# TODO: Add ACT2FN reference to change activation function
|
||||||
|
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
|
||||||
|
return gelu(dense)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertOutput(nn.Module):
|
||||||
|
@compact
|
||||||
|
def __call__(self, intermediate_output, attention_output):
|
||||||
|
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
|
||||||
|
hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output)
|
||||||
|
return hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertLayer(nn.Module):
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, hidden_state, attention_mask):
|
||||||
|
attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask)
|
||||||
|
intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention)
|
||||||
|
output = FlaxBertOutput(name="output")(intermediate, attention)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertLayerCollection(nn.Module):
|
||||||
|
"""
|
||||||
|
Stores N BertLayer(s)
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_layers: int
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, inputs, attention_mask):
|
||||||
|
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
|
||||||
|
|
||||||
|
# Initialize input / output
|
||||||
|
input_i = inputs
|
||||||
|
|
||||||
|
# Forward over all encoders
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
layer = FlaxBertLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
|
||||||
|
input_i = layer(input_i, attention_mask)
|
||||||
|
return input_i
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertEncoder(nn.Module):
|
||||||
|
num_layers: int
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, hidden_state, attention_mask):
|
||||||
|
layer = FlaxBertLayerCollection(
|
||||||
|
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
|
||||||
|
)(hidden_state, attention_mask)
|
||||||
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertPooler(nn.Module):
|
||||||
|
@compact
|
||||||
|
def __call__(self, hidden_state):
|
||||||
|
cls_token = hidden_state[:, 0]
|
||||||
|
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
|
||||||
|
return jax.lax.tanh(out)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertModule(nn.Module):
|
||||||
|
vocab_size: int
|
||||||
|
hidden_size: int
|
||||||
|
type_vocab_size: int
|
||||||
|
max_length: int
|
||||||
|
num_encoder_layers: int
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||||
|
|
||||||
|
# Embedding
|
||||||
|
embeddings = FlaxBertEmbeddings(
|
||||||
|
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
|
||||||
|
)(input_ids, token_type_ids, position_ids, attention_mask)
|
||||||
|
|
||||||
|
# N stacked encoding layers
|
||||||
|
encoder = FlaxBertEncoder(
|
||||||
|
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
|
||||||
|
)(embeddings, attention_mask)
|
||||||
|
|
||||||
|
pooled = FlaxBertPooler(name="pooler")(encoder)
|
||||||
|
return encoder, pooled
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
BERT_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class FlaxBertModel(FlaxPreTrainedModel):
|
||||||
|
"""
|
||||||
|
The model can behave as an encoder (with only self-attention) as well
|
||||||
|
as a decoder, in which case a layer of cross-attention is added between
|
||||||
|
the self-attention layers, following the architecture described in `Attention is all you need
|
||||||
|
<https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
|
||||||
|
Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_class = FlaxBertModule
|
||||||
|
config_class = BertConfig
|
||||||
|
base_model_prefix = "bert"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
|
||||||
|
jax_state = dict(pt_state)
|
||||||
|
|
||||||
|
# Rename some parameters to match the Flax names so that we don't have to fork any layer
|
||||||
|
for key, tensor in pt_state.items():
|
||||||
|
# Key parts
|
||||||
|
key_parts = set(key.split("."))
|
||||||
|
|
||||||
|
# Every dense layer has "kernel" parameters instead of "weight"
|
||||||
|
if "dense.weight" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
key = key.replace("weight", "kernel")
|
||||||
|
jax_state[key] = tensor
|
||||||
|
|
||||||
|
# SelfAttention needs also to replace "weight" by "kernel"
|
||||||
|
if {"query", "key", "value"} & key_parts:
|
||||||
|
|
||||||
|
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
||||||
|
if "bias" in key:
|
||||||
|
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
||||||
|
elif "weight" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
key = key.replace("weight", "kernel")
|
||||||
|
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
|
||||||
|
jax_state[key] = tensor
|
||||||
|
|
||||||
|
# SelfAttention output is not a separate layer, remove one nesting
|
||||||
|
if "attention.output.dense" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
key = key.replace("attention.output.dense", "attention.self.out")
|
||||||
|
jax_state[key] = tensor
|
||||||
|
|
||||||
|
# SelfAttention output is not a separate layer, remove nesting on layer norm
|
||||||
|
if "attention.output.LayerNorm" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
|
||||||
|
jax_state[key] = tensor
|
||||||
|
|
||||||
|
# There are some transposed parameters w.r.t their PyTorch counterpart
|
||||||
|
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
|
||||||
|
jax_state[key] = tensor.T
|
||||||
|
|
||||||
|
# Self Attention output projection needs to be transposed
|
||||||
|
if "out.kernel" in key:
|
||||||
|
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
|
||||||
|
1, 2, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pooler needs to transpose its kernel
|
||||||
|
if "pooler.dense.kernel" in key:
|
||||||
|
jax_state[key] = tensor.T
|
||||||
|
|
||||||
|
# Handle LayerNorm conversion
|
||||||
|
if "LayerNorm" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
|
||||||
|
# Replace LayerNorm by layer_norm
|
||||||
|
new_key = key.replace("LayerNorm", "layer_norm")
|
||||||
|
|
||||||
|
if "weight" in key:
|
||||||
|
new_key = new_key.replace("weight", "gamma")
|
||||||
|
elif "bias" in key:
|
||||||
|
new_key = new_key.replace("bias", "beta")
|
||||||
|
|
||||||
|
jax_state[new_key] = tensor
|
||||||
|
|
||||||
|
return jax_state
|
||||||
|
|
||||||
|
def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs):
|
||||||
|
model = FlaxBertModule(
|
||||||
|
vocab_size=config.vocab_size,
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
type_vocab_size=config.type_vocab_size,
|
||||||
|
max_length=config.max_position_embeddings,
|
||||||
|
num_encoder_layers=config.num_hidden_layers,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
head_size=config.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(config, model, state, seed)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def module(self) -> nn.Module:
|
||||||
|
return self._module
|
||||||
|
|
||||||
|
def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
return self.model.apply(
|
||||||
|
{"params": self.params},
|
||||||
|
jnp.array(input_ids, dtype="i4"),
|
||||||
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
|
jnp.array(position_ids, dtype="i4"),
|
||||||
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
|
)
|
||||||
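To make the renaming rules in `convert_from_pytorch` above easier to follow, here is a minimal sketch that applies them to parameter names only, with no tensors involved. The helper `rename_key` is hypothetical and not part of this commit; it mirrors the string substitutions above and deliberately leaves out the reshape and transpose steps.

```python
def rename_key(key: str) -> str:
    """Map a PyTorch-style BERT parameter name to its Flax-style name.

    Hypothetical helper for illustration only: it covers the renaming
    rules, not the tensor reshapes/transposes done during conversion.
    """
    parts = set(key.split("."))

    # Dense layers call their matrix "kernel" instead of "weight".
    if "dense.weight" in key:
        key = key.replace("weight", "kernel")

    # Query/key/value projections follow the same convention.
    if {"query", "key", "value"} & parts and key.endswith(".weight"):
        key = key[: -len("weight")] + "kernel"

    # The attention output projection and its layer norm lose one nesting level.
    key = key.replace("attention.output.dense", "attention.self.out")
    key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")

    # LayerNorm parameters become layer_norm.gamma / layer_norm.beta.
    if "LayerNorm" in key:
        if key.endswith(".weight"):
            key = key[: -len("weight")] + "gamma"
        elif key.endswith(".bias"):
            key = key[: -len("bias")] + "beta"
        key = key.replace("LayerNorm", "layer_norm")

    return key


examples = [
    "encoder.layer.0.attention.self.query.weight",
    "encoder.layer.0.attention.output.dense.weight",
    "encoder.layer.0.attention.output.LayerNorm.bias",
    "embeddings.word_embeddings.weight",
]
for name in examples:
    print(name, "->", rename_key(name))
```

Embedding tables keep the name `weight` because `FlaxBertEmbedding` registers its parameter under that name, so only dense, attention and layer-norm entries change.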
450
src/transformers/modeling_flax_roberta.py
Normal file
@@ -0,0 +1,450 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Callable, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from flax.linen import compact
|
||||||
|
|
||||||
|
from .configuration_roberta import RobertaConfig
|
||||||
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_flax_utils import FlaxPreTrainedModel, gelu
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CONFIG_FOR_DOC = "RobertaConfig"
|
||||||
|
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
||||||
|
|
||||||
|
|
||||||
|
ROBERTA_START_DOCSTRING = r"""
|
||||||
|
|
||||||
|
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||||
|
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
|
||||||
|
pruning heads etc.)
|
||||||
|
|
||||||
|
The underlying :class:`FlaxRobertaModule` is a Flax Linen ``flax.linen.Module``, not a PyTorch ``torch.nn.Module``.
|
||||||
|
Refer to the Flax documentation for all matters related to general
|
||||||
|
usage and behavior.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
|
||||||
|
model. Initializing with a config file does not load the weights associated with the model, only the configuration.
|
||||||
|
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ROBERTA_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
||||||
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.RobertaTokenizer`.
|
||||||
|
See :meth:`transformers.PreTrainedTokenizer.encode` and
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.__call__` for details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
||||||
|
Mask to avoid performing attention on padding token indices.
|
||||||
|
Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||||
|
Segment token indices to indicate first and second portions of the inputs.
|
||||||
|
Indices are selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 0 corresponds to a `sentence A` token,
|
||||||
|
- 1 corresponds to a `sentence B` token.
|
||||||
|
|
||||||
|
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||||
|
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||||
|
Indices of the position of each input sequence token in the position embeddings.
|
||||||
|
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
||||||
|
|
||||||
|
`What are position IDs? <../glossary.html#position-ids>`_
|
||||||
|
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||||
|
Mask to nullify selected heads of the self-attention modules.
|
||||||
|
Mask values selected in ``[0, 1]``:
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
|
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
||||||
|
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||||
|
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||||
|
vectors than the model's internal embedding lookup matrix.
|
||||||
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (:obj:`bool`, `optional`):
|
||||||
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta
|
||||||
|
class FlaxRobertaLayerNorm(nn.Module):
|
||||||
|
"""Layer normalization (https://arxiv.org/abs/1607.06450).
|
||||||
|
Operates on the last axis of the input data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
epsilon: float = 1e-6
|
||||||
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
bias: bool = True
|
||||||
|
scale: bool = True
|
||||||
|
bias_init: jnp.ndarray = nn.initializers.zeros
|
||||||
|
scale_init: jnp.ndarray = nn.initializers.ones
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, x):
|
||||||
|
"""Applies layer normalization on the input.
|
||||||
|
It normalizes the activations of the layer for each given example in a
|
||||||
|
batch independently, rather than across a batch like Batch Normalization.
|
||||||
|
i.e. applies a transformation that maintains the mean activation within
|
||||||
|
each example close to 0 and the activation standard deviation close to 1.
|
||||||
|
Args:
|
||||||
|
x: the inputs
|
||||||
|
epsilon: A small float added to variance to avoid dividing by zero.
|
||||||
|
dtype: the dtype of the computation (default: float32).
|
||||||
|
bias: If True, bias (beta) is added.
|
||||||
|
scale: If True, multiply by scale (gamma). When the next layer is linear
|
||||||
|
(also e.g. nn.relu), this can be disabled since the scaling will be done
|
||||||
|
by the next layer.
|
||||||
|
bias_init: Initializer for bias, by default, zero.
|
||||||
|
scale_init: Initializer for scale, by default, one.
|
||||||
|
Returns:
|
||||||
|
Normalized inputs (the same shape as inputs).
|
||||||
|
"""
|
||||||
|
features = x.shape[-1]
|
||||||
|
mean = jnp.mean(x, axis=-1, keepdims=True)
|
||||||
|
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||||
|
var = mean2 - jax.lax.square(mean)
|
||||||
|
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||||
|
if self.scale:
|
||||||
|
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
|
||||||
|
y = (x - mean) * mul
|
||||||
|
if self.bias:
|
||||||
|
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta
|
||||||
|
class FlaxRobertaEmbedding(nn.Module):
|
||||||
|
"""
|
||||||
|
Custom embedding layer, needed because Flax's built-in one
|
||||||
|
names its parameter 'embedding'
|
||||||
|
whereas PyTorch uses 'weight'
|
||||||
|
"""
|
||||||
|
|
||||||
|
vocab_size: int
|
||||||
|
hidden_size: int
|
||||||
|
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, inputs):
|
||||||
|
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
|
||||||
|
return jnp.take(embedding, inputs, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
|
||||||
|
class FlaxRobertaEmbeddings(nn.Module):
|
||||||
|
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||||
|
|
||||||
|
vocab_size: int
|
||||||
|
hidden_size: int
|
||||||
|
type_vocab_size: int
|
||||||
|
max_length: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||||
|
|
||||||
|
# Embed
|
||||||
|
w_emb = FlaxRobertaEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
|
||||||
|
jnp.atleast_2d(input_ids.astype("i4"))
|
||||||
|
)
|
||||||
|
p_emb = FlaxRobertaEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
|
||||||
|
jnp.atleast_2d(position_ids.astype("i4"))
|
||||||
|
)
|
||||||
|
t_emb = FlaxRobertaEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
|
||||||
|
jnp.atleast_2d(token_type_ids.astype("i4"))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sum all embeddings
|
||||||
|
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
|
||||||
|
|
||||||
|
# Layer Norm
|
||||||
|
layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(summed_emb)
|
||||||
|
|
||||||
|
return layer_norm
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
||||||
|
class FlaxRobertaAttention(nn.Module):
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, hidden_state, attention_mask):
|
||||||
|
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
|
||||||
|
hidden_state, attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(self_att + hidden_state)
|
||||||
|
return layer_norm
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
||||||
|
class FlaxRobertaIntermediate(nn.Module):
|
||||||
|
output_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, hidden_state):
|
||||||
|
# TODO: Add ACT2FN reference to change activation function
|
||||||
|
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
|
||||||
|
return gelu(dense)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
||||||
|
class FlaxRobertaOutput(nn.Module):
|
||||||
|
@compact
|
||||||
|
def __call__(self, intermediate_output, attention_output):
|
||||||
|
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
|
||||||
|
hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output)
|
||||||
|
return hidden_state
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxRobertaLayer(nn.Module):
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, hidden_state, attention_mask):
|
||||||
|
attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")(
|
||||||
|
hidden_state, attention_mask
|
||||||
|
)
|
||||||
|
intermediate = FlaxRobertaIntermediate(self.intermediate_size, name="intermediate")(attention)
|
||||||
|
output = FlaxRobertaOutput(name="output")(intermediate, attention)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
|
||||||
|
class FlaxRobertaLayerCollection(nn.Module):
|
||||||
|
"""
|
||||||
|
Stores N RobertaLayer(s)
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_layers: int
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, inputs, attention_mask):
|
||||||
|
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
|
||||||
|
|
||||||
|
# Initialize input / output
|
||||||
|
input_i = inputs
|
||||||
|
|
||||||
|
# Forward over all encoders
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
layer = FlaxRobertaLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}")
|
||||||
|
input_i = layer(input_i, attention_mask)
|
||||||
|
return input_i
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
|
||||||
|
class FlaxRobertaEncoder(nn.Module):
|
||||||
|
num_layers: int
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, hidden_state, attention_mask):
|
||||||
|
layer = FlaxRobertaLayerCollection(
|
||||||
|
self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer"
|
||||||
|
)(hidden_state, attention_mask)
|
||||||
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
||||||
|
class FlaxRobertaPooler(nn.Module):
|
||||||
|
@compact
|
||||||
|
def __call__(self, hidden_state):
|
||||||
|
cls_token = hidden_state[:, 0]
|
||||||
|
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
|
||||||
|
return jax.lax.tanh(out)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
||||||
|
class FlaxRobertaModule(nn.Module):
|
||||||
|
vocab_size: int
|
||||||
|
hidden_size: int
|
||||||
|
type_vocab_size: int
|
||||||
|
max_length: int
|
||||||
|
num_encoder_layers: int
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
|
||||||
|
@compact
|
||||||
|
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||||
|
|
||||||
|
# Embedding
|
||||||
|
embeddings = FlaxRobertaEmbeddings(
|
||||||
|
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
|
||||||
|
)(input_ids, token_type_ids, position_ids, attention_mask)
|
||||||
|
|
||||||
|
# N stacked encoding layers
|
||||||
|
encoder = FlaxRobertaEncoder(
|
||||||
|
self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder"
|
||||||
|
)(embeddings, attention_mask)
|
||||||
|
|
||||||
|
pooled = FlaxRobertaPooler(name="pooler")(encoder)
|
||||||
|
return encoder, pooled
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
ROBERTA_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||||
|
"""
|
||||||
|
The model can behave as an encoder (with only self-attention) as well
|
||||||
|
as a decoder, in which case a layer of cross-attention is added between
|
||||||
|
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
|
||||||
|
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_class = FlaxRobertaModule
|
||||||
|
config_class = RobertaConfig
|
||||||
|
base_model_prefix = "roberta"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_from_pytorch(pt_state: Dict, config: RobertaConfig) -> Dict:
|
||||||
|
jax_state = dict(pt_state)
|
||||||
|
|
||||||
|
# Rename some parameters to match the Flax names so that we don't have to fork any layer
|
||||||
|
for key, tensor in pt_state.items():
|
||||||
|
# Key parts
|
||||||
|
key_parts = set(key.split("."))
|
||||||
|
|
||||||
|
# Every dense layer has "kernel" parameters instead of "weight"
|
||||||
|
if "dense.weight" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
key = key.replace("weight", "kernel")
|
||||||
|
jax_state[key] = tensor
|
||||||
|
|
||||||
|
# SelfAttention needs also to replace "weight" by "kernel"
|
||||||
|
if {"query", "key", "value"} & key_parts:
|
||||||
|
|
||||||
|
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
||||||
|
if "bias" in key:
|
||||||
|
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
||||||
|
elif "weight" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
key = key.replace("weight", "kernel")
|
||||||
|
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
|
||||||
|
jax_state[key] = tensor
|
||||||
|
|
||||||
|
# SelfAttention output is not a separate layer, remove one nesting
|
||||||
|
if "attention.output.dense" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
key = key.replace("attention.output.dense", "attention.self.out")
|
||||||
|
jax_state[key] = tensor
|
||||||
|
|
||||||
|
# SelfAttention output is not a separate layer, remove nesting on layer norm
|
||||||
|
if "attention.output.LayerNorm" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
|
||||||
|
jax_state[key] = tensor
|
||||||
|
|
||||||
|
# There are some transposed parameters w.r.t their PyTorch counterpart
|
||||||
|
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
|
||||||
|
jax_state[key] = tensor.T
|
||||||
|
|
||||||
|
# Self Attention output projection needs to be transposed
|
||||||
|
if "out.kernel" in key:
|
||||||
|
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
|
||||||
|
1, 2, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pooler needs to transpose its kernel
|
||||||
|
if "pooler.dense.kernel" in key:
|
||||||
|
jax_state[key] = tensor.T
|
||||||
|
|
||||||
|
# Handle LayerNorm conversion
|
||||||
|
if "LayerNorm" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
|
||||||
|
# Replace LayerNorm by layer_norm
|
||||||
|
new_key = key.replace("LayerNorm", "layer_norm")
|
||||||
|
|
||||||
|
if "weight" in key:
|
||||||
|
new_key = new_key.replace("weight", "gamma")
|
||||||
|
elif "bias" in key:
|
||||||
|
new_key = new_key.replace("bias", "beta")
|
||||||
|
|
||||||
|
jax_state[new_key] = tensor
|
||||||
|
|
||||||
|
return jax_state
|
||||||
|
|
||||||
|
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, **kwargs):
|
||||||
|
model = FlaxRobertaModule(
|
||||||
|
vocab_size=config.vocab_size,
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
type_vocab_size=config.type_vocab_size,
|
||||||
|
max_length=config.max_position_embeddings,
|
||||||
|
num_encoder_layers=config.num_hidden_layers,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
head_size=config.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(config, model, state, seed)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def module(self) -> nn.Module:
|
||||||
|
return self._module
|
||||||
|
|
||||||
|
def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = np.arange(
|
||||||
|
self.config.pad_token_id + 1, np.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
return self.model.apply(
|
||||||
|
{"params": self.params},
|
||||||
|
jnp.array(input_ids, dtype="i4"),
|
||||||
|
jnp.array(token_type_ids, dtype="i4"),
|
||||||
|
jnp.array(position_ids, dtype="i4"),
|
||||||
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
|
)
|
||||||
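The default `position_ids` in `__call__` above start right after the padding index instead of at zero. A minimal sketch of that arithmetic, assuming `pad_token_id=1` and made-up token ids purely for illustration; `default_position_ids` is a hypothetical helper, not a function from this PR.

```python
import numpy as np


def default_position_ids(input_ids, pad_token_id):
    """Positions run from pad_token_id + 1 for the length of the sequence."""
    seq_len = np.atleast_2d(input_ids).shape[-1]
    return np.arange(pad_token_id + 1, seq_len + pad_token_id + 1)


# One sequence of four (made-up) token ids.
ids = default_position_ids([[0, 31414, 232, 2]], pad_token_id=1)
```

The result is a 1-D array, so it broadcasts across every row of a batch.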
194
src/transformers/modeling_flax_utils.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pickle import UnpicklingError
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import flax.linen as nn
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from flax.serialization import to_bytes
|
||||||
|
from flax.traverse_util import unflatten_dict
|
||||||
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .file_utils import WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def gelu(x):
|
||||||
|
r"""Gaussian error linear unit activation function.
|
||||||
|
|
||||||
|
Computes the element-wise function:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
|
||||||
|
\sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
|
||||||
|
|
||||||
|
We explicitly use the approximation rather than the exact formulation for
|
||||||
|
speed. For more information, see `Gaussian Error Linear Units (GELUs)
|
||||||
|
<https://arxiv.org/abs/1606.08415>`_, section 2.
|
||||||
|
"""
|
||||||
|
return x * 0.5 * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0)))
|
||||||
|
|
||||||
|
|
||||||
|
ACT2FN = {
|
||||||
|
"gelu": nn.gelu,
|
||||||
|
"relu": nn.relu,
|
||||||
|
"swish": nn.swish,
|
||||||
|
"gelu_new": gelu,
|
||||||
|
}
|
||||||
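The `gelu` docstring above gives the tanh approximation, while the return statement evaluates the erf form. The two are close but not identical. The following sketch, using only the standard library, compares them on a grid of points; the function names and the grid are illustrative choices, not part of this PR.

```python
import math


def gelu_exact(x):
    # Exact form: x * Phi(x), with Phi the standard normal CDF.
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    # Tanh approximation, as written in the docstring formula.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Largest gap between the two forms on [-5, 5] in steps of 0.1.
max_gap = max(abs(gelu_exact(v / 10) - gelu_tanh(v / 10)) for v in range(-50, 51))
```

When checking numerical parity with another framework, it may be worth confirming which of the two forms each entry of `ACT2FN` actually resolves to.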
|
|
||||||
|
|
||||||
|
class FlaxPreTrainedModel(ABC):
|
||||||
|
config_class = None
|
||||||
|
pretrained_model_archive_map = {}
|
||||||
|
base_model_prefix = ""
|
||||||
|
model_class = None
|
||||||
|
|
||||||
|
def __init__(self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0):
|
||||||
|
if config is None:
|
||||||
|
raise ValueError("config cannot be None")
|
||||||
|
|
||||||
|
if module is None:
|
||||||
|
raise ValueError("module cannot be None")
|
||||||
|
|
||||||
|
if params is None:
|
||||||
|
raise ValueError("params cannot be None")
|
||||||
|
|
||||||
|
# These are private so that derived classes can expose them as typed properties.
|
||||||
|
self._config = config
|
||||||
|
self._module = module
|
||||||
|
|
||||||
|
# These are public because their type is the same for every derived class.
|
||||||
|
self.key = PRNGKey(seed)
|
||||||
|
self.params = params
|
||||||
|
self.model = module
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config(self) -> PretrainedConfig:
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
|
r"""
|
||||||
|
Instantiate a pretrained Flax model from a pre-trained model configuration.
|
||||||
|
"""
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
# state_dict = kwargs.pop("state_dict", None)
|
||||||
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
|
# from_tf = kwargs.pop("from_tf", False)
|
||||||
|
force_download = kwargs.pop("force_download", False)
|
||||||
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
# output_loading_info = kwargs.pop("output_loading_info", False)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
|
use_cdn = kwargs.pop("use_cdn", True)
|
||||||
|
|
||||||
|
# Load config if we don't provide a configuration
|
||||||
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config_path = config if config is not None else pretrained_model_name_or_path
|
||||||
|
config, model_kwargs = cls.config_class.from_pretrained(
|
||||||
|
config_path,
|
||||||
|
*model_args,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
return_unused_kwargs=True,
|
||||||
|
force_download=force_download,
|
||||||
|
resume_download=resume_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_kwargs = kwargs
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
if pretrained_model_name_or_path is not None:
|
||||||
|
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
|
archive_file = pretrained_model_name_or_path
|
||||||
|
else:
|
||||||
|
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, use_cdn=use_cdn)
|
||||||
|
|
||||||
|
# redirect to the cache, if necessary
|
||||||
|
try:
|
||||||
|
resolved_archive_file = cached_path(
|
||||||
|
archive_file,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
except EnvironmentError:
|
||||||
|
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||||
|
msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
|
||||||
|
else:
|
||||||
|
msg = (
|
||||||
|
f"Model name '{pretrained_model_name_or_path}' "
|
||||||
|
f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
|
||||||
|
f"We assumed '{archive_file}' was a path or url to model weight files but "
|
||||||
|
"couldn't find any such file at this path or url."
|
||||||
|
)
|
||||||
|
raise EnvironmentError(msg)
|
||||||
|
|
||||||
|
if resolved_archive_file == archive_file:
|
||||||
|
logger.info(f"loading weights file {archive_file}")
|
||||||
|
else:
|
||||||
|
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
||||||
|
else:
|
||||||
|
resolved_archive_file = None
|
||||||
|
|
||||||
|
# Instantiate model.
|
||||||
|
with open(resolved_archive_file, "rb") as state_f:
|
||||||
|
try:
|
||||||
|
from flax.serialization import from_bytes
|
||||||
|
|
||||||
|
state = from_bytes(cls.model_class, state_f)
|
||||||
|
except TypeError:
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
state = torch.load(state_f)
|
||||||
|
state = {k: v.numpy() for k, v in state.items()}
|
||||||
|
state = cls.convert_from_pytorch(state, config)
|
||||||
|
state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
|
||||||
|
except UnpicklingError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Unable to convert model {archive_file} to Flax deserializable object. "
|
||||||
|
"Supported formats are PyTorch archive or Flax msgpack."
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(config, state, *model_args, **model_kwargs)
|
||||||
|
|
||||||
|
def save_pretrained(self, folder):
|
||||||
|
folder_abs = os.path.abspath(folder)
|
||||||
|
|
||||||
|
if not os.path.exists(folder_abs):
|
||||||
|
os.mkdir(folder_abs)
|
||||||
|
|
||||||
|
with open(os.path.join(folder_abs, f"{self._config.model_type}.flax"), "wb") as f:
|
||||||
|
model_bytes = to_bytes(self.params)
|
||||||
|
f.write(model_bytes)
|
||||||
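`save_pretrained` above writes the serialized parameters to a `<model_type>.flax` file inside the target folder. As a sketch of the intended file handling only, the example below uses plain bytes as a stand-in for `to_bytes(self.params)`; `save_bytes` is a hypothetical helper and does not exist in this PR.

```python
import os
import tempfile


def save_bytes(folder, model_type, payload):
    """Hypothetical helper: write payload to <folder>/<model_type>.flax."""
    folder_abs = os.path.abspath(folder)
    # makedirs with exist_ok=True also creates missing parent directories.
    os.makedirs(folder_abs, exist_ok=True)
    path = os.path.join(folder_abs, f"{model_type}.flax")
    # The mode is an argument of open(), not of os.path.join().
    with open(path, "wb") as f:
        f.write(payload)
    return path


with tempfile.TemporaryDirectory() as tmp:
    written = save_bytes(os.path.join(tmp, "out"), "bert", b"\x00\x01")
    with open(written, "rb") as f:
        roundtrip = f.read()
```

Using `os.makedirs(..., exist_ok=True)` instead of an existence check followed by `os.mkdir` is a design choice of this sketch; it avoids a failure when the parent directory is missing.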
@@ -13,6 +13,7 @@ from pathlib import Path
|
|||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
_datasets_available,
|
_datasets_available,
|
||||||
_faiss_available,
|
_faiss_available,
|
||||||
|
_flax_available,
|
||||||
_sentencepiece_available,
|
_sentencepiece_available,
|
||||||
_tf_available,
|
_tf_available,
|
||||||
_tokenizers_available,
|
_tokenizers_available,
|
||||||
@@ -115,6 +116,18 @@ def require_tf(test_case):
|
|||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
def require_flax(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires JAX & Flax
|
||||||
|
|
||||||
|
These tests are skipped when either or both are not installed.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not _flax_available:
|
||||||
|
test_case = unittest.skip("test requires JAX & Flax")(test_case)
|
||||||
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
def require_sentencepiece(test_case):
|
def require_sentencepiece(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires SentencePiece.
|
Decorator marking a test that requires SentencePiece.
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from .file_utils import (
|
|||||||
add_end_docstrings,
|
add_end_docstrings,
|
||||||
cached_path,
|
cached_path,
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
|
is_flax_available,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
@@ -47,6 +48,8 @@ if is_tf_available():
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
if is_flax_available():
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
from tokenizers import AddedToken
|
from tokenizers import AddedToken
|
||||||
@@ -143,6 +146,7 @@ class TensorType(ExplicitEnum):
|
|||||||
PYTORCH = "pt"
|
PYTORCH = "pt"
|
||||||
TENSORFLOW = "tf"
|
TENSORFLOW = "tf"
|
||||||
NUMPY = "np"
|
NUMPY = "np"
|
||||||
|
JAX = "jax"
|
||||||
|
|
||||||
|
|
||||||
class CharSpan(NamedTuple):
|
class CharSpan(NamedTuple):
|
||||||
@@ -559,18 +563,27 @@ class BatchEncoding(UserDict):
|
|||||||
tensor_type = TensorType(tensor_type)
|
tensor_type = TensorType(tensor_type)
|
||||||
|
|
||||||
# Get a function reference for the correct framework
|
# Get a function reference for the correct framework
|
||||||
if tensor_type == TensorType.TENSORFLOW and is_tf_available():
|
if tensor_type == TensorType.TENSORFLOW:
|
||||||
as_tensor = tf.constant
|
if not is_tf_available():
|
||||||
elif tensor_type == TensorType.PYTORCH and is_torch_available():
|
|
||||||
as_tensor = torch.tensor
|
|
||||||
elif tensor_type == TensorType.NUMPY:
|
|
||||||
as_tensor = np.asarray
|
|
||||||
else:
|
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
|
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
|
||||||
tensor_type
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
as_tensor = tf.constant
|
||||||
|
elif tensor_type == TensorType.PYTORCH:
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||||
|
as_tensor = torch.tensor
|
||||||
|
elif tensor_type == TensorType.JAX:
|
||||||
|
if not is_flax_available():
|
||||||
|
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
|
||||||
|
as_tensor = jnp.array
|
||||||
|
else:
|
||||||
|
as_tensor = np.asarray
|
||||||
|
# (mfuntowicz: This code is unreachable)
|
||||||
|
# else:
|
||||||
|
# raise ImportError(
|
||||||
|
# "Unable to convert output to tensors format {}".format(tensor_type)
|
||||||
|
# )
|
||||||
|
|
||||||
# Do the tensor conversion in batch
|
# Do the tensor conversion in batch
|
||||||
for key, value in self.items():
|
for key, value in self.items():
|
||||||
|
|||||||
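The hunk above replaces a single catch-all `ImportError` with one check per framework, so the message names the missing package. A minimal sketch of the same idea with the standard library is shown below. The mapping `REQUIRED_PACKAGE` and the helper `check_backend` are hypothetical; in particular, the PR itself gates JAX on `is_flax_available()` rather than on a package lookup.

```python
import importlib.util

# Hypothetical mapping from TensorType values to the package that must be
# importable for that format.
REQUIRED_PACKAGE = {"tf": "tensorflow", "pt": "torch", "jax": "jax", "np": "numpy"}


def check_backend(tensor_type):
    """Raise ImportError early if the backend for tensor_type is missing."""
    package = REQUIRED_PACKAGE.get(tensor_type)
    if package is None:
        raise ValueError(f"Unknown tensor type: {tensor_type!r}")
    # find_spec returns None for a top-level package that is not installed.
    if importlib.util.find_spec(package) is None:
        raise ImportError(
            f"Unable to convert output to {tensor_type!r} tensors, {package} is not installed."
        )
    return package
```

Failing before any conversion starts keeps the error close to its cause, instead of surfacing later as a `NameError` on an unimported module.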
64
tests/test_flax_auto.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import AutoConfig, AutoTokenizer, BertConfig, TensorType, is_flax_available
|
||||||
|
from transformers.testing_utils import require_flax, slow
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import jax
|
||||||
|
from transformers.modeling_flax_auto import FlaxAutoModel
|
||||||
|
from transformers.modeling_flax_bert import FlaxBertModel
|
||||||
|
from transformers.modeling_flax_roberta import FlaxRobertaModel
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
class FlaxAutoModelTest(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_bert_from_pretrained(self):
|
||||||
|
for model_name in ["bert-base-cased", "bert-large-uncased"]:
|
||||||
|
with self.subTest(model_name):
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, BertConfig)
|
||||||
|
|
||||||
|
model = FlaxAutoModel.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, FlaxBertModel)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_roberta_from_pretrained(self):
|
||||||
|
for model_name in ["roberta-base", "roberta-large"]:
|
||||||
|
with self.subTest(model_name):
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, BertConfig)
|
||||||
|
|
||||||
|
model = FlaxAutoModel.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, FlaxRobertaModel)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_bert_jax_jit(self):
|
||||||
|
for model_name in ["bert-base-cased", "bert-large-uncased"]:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
model = FlaxBertModel.from_pretrained(model_name)
|
||||||
|
tokens = tokenizer("Do you support jax jitted function?", return_tensors=TensorType.JAX)
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def eval(**kwargs):
|
||||||
|
return model(**kwargs)
|
||||||
|
|
||||||
|
eval(**tokens).block_until_ready()
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_roberta_jax_jit(self):
|
||||||
|
for model_name in ["roberta-base", "roberta-large"]:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
model = FlaxRobertaModel.from_pretrained(model_name)
|
||||||
|
tokens = tokenizer("Do you support jax jitted function?", return_tensors=TensorType.JAX)
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def eval(**kwargs):
|
||||||
|
return model(**kwargs)
|
||||||
|
|
||||||
|
eval(**tokens).block_until_ready()
|
||||||
42
tests/test_modeling_flax_bert.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from numpy import ndarray
|
||||||
|
|
||||||
|
from transformers import TensorType, is_flax_available, is_torch_available
|
||||||
|
from transformers.testing_utils import require_flax, require_torch
|
||||||
|
from transformers.tokenization_bert_fast import BertTokenizerFast
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
from transformers.modeling_flax_bert import FlaxBertModel
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers.modeling_bert import BertModel
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
@require_torch
|
||||||
|
class FlaxBertModelTest(unittest.TestCase):
|
||||||
|
def test_from_pytorch(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
with self.subTest("bert-base-cased"):
|
||||||
|
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||||
|
fx_model = FlaxBertModel.from_pretrained("bert-base-cased")
|
||||||
|
pt_model = BertModel.from_pretrained("bert-base-cased")
|
||||||
|
|
||||||
|
# Check for simple input
|
||||||
|
pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
|
||||||
|
fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
|
||||||
|
pt_outputs = pt_model(**pt_inputs)
|
||||||
|
fx_outputs = fx_model(**fx_inputs)
|
||||||
|
|
||||||
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
|
|
||||||
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-4)
|
||||||
|
|
||||||
|
def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
|
||||||
|
diff = abs(a - b).max()
|
||||||
|
self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (> {})".format(diff, tol))
|
||||||
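When comparing Flax and PyTorch outputs, a signed sum of differences lets errors of opposite sign cancel, and a large negative total always passes a "less than or equal" check. The sketch below contrasts that with the largest element-wise absolute difference; `max_abs_diff` and the toy arrays are illustrative and not part of this PR.

```python
import numpy as np


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two arrays."""
    return float(np.max(np.abs(np.asarray(a) - np.asarray(b))))


# Two arrays that differ by +1 and -1: the signed sum hides the mismatch.
a = np.array([1.0, -1.0])
b = np.array([0.0, 0.0])
signed_sum = float((a - b).sum())  # errors of opposite sign cancel
worst_case = max_abs_diff(a, b)
```

A per-element bound such as this one also keeps the tolerance independent of the sequence length and hidden size.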
42
tests/test_modeling_flax_roberta.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from numpy import ndarray
|
||||||
|
|
||||||
|
from transformers import TensorType, is_flax_available, is_torch_available
|
||||||
|
from transformers.testing_utils import require_flax, require_torch
|
||||||
|
from transformers.tokenization_roberta_fast import RobertaTokenizerFast
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
from transformers.modeling_flax_roberta import FlaxRobertaModel
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers.modeling_roberta import RobertaModel
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
@require_torch
|
||||||
|
class FlaxRobertaModelTest(unittest.TestCase):
|
||||||
|
def test_from_pytorch(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
with self.subTest("roberta-base"):
|
||||||
|
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
|
||||||
|
fx_model = FlaxRobertaModel.from_pretrained("roberta-base")
|
||||||
|
pt_model = RobertaModel.from_pretrained("roberta-base")
|
||||||
|
|
||||||
|
# Check for simple input
|
||||||
|
pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
|
||||||
|
fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
|
||||||
|
pt_outputs = pt_model(**pt_inputs)
|
||||||
|
fx_outputs = fx_model(**fx_inputs)
|
||||||
|
|
||||||
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
|
|
||||||
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-4)
|
||||||
|
|
||||||
|
def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
|
||||||
|
diff = abs(a - b).max()
|
||||||
|
self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (> {})".format(diff, tol))
|
||||||