Integrate Bert-like model on Flax runtime. (#3722)

* WIP flax bert

* Initial commit Bert Jax/Flax implementation.

* Embeddings working and equivalent to PyTorch.

* Move embeddings in its own module BertEmbeddings

* Added jax.jit annotation on forward call

* BertEncoder on par with PyTorch ! :D

* Add BertPooler on par with PyTorch !!

* Working Jax+Flax implementation of BertModel with < 1e-5 differences on the last layer.

* Fix pooled output to take only the first token of the sequence.

* Refactoring to use BertConfig from transformers.

* Renamed FXBertModel to FlaxBertModel

* Model is now initialized in FlaxBertModel constructor and reused.

* WIP JaxPreTrainedModel

* Cleaning up the code of FlaxBertModel

* Added ability to load Flax model saved through save_pretrained()

* Added ability to convert Pytorch Bert model to FlaxBert

* FlaxBert can now load every Pytorch Bert model with on-the-fly conversion

* Fix hardcoded shape values in conversion scripts.

* Improve the way we handle LayerNorm conversion from PyTorch to Flax.

* Added positional embeddings as parameter of BertModel with default to np.arange.

* Let's roll FlaxRoberta !

* Fix missing position_ids parameters on predict for Bert

* Flax backend now supports batched inputs

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Make it possible to load msgpacked model on convert from pytorch in last resort.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Moved save_pretrained to Jax base class along with more constructor parameters.

* Use specialized, model dependent conversion functio.

* Expose `is_flax_available` in file_utils.

* Added unittest for Flax models.

* Added run_tests_flax to the CI.

* Introduce FlaxAutoModel

* Added more unittests

* Flax model reference the _MODEL_ARCHIVE_MAP from PyTorch model.

* Addressing review comments.

* Expose seed in both Bert and Roberta

* Fix typo suggested by @stefan-it

Co-Authored-By: Stefan Schweter <stefan@schweter.it>

* Attempt to make style

* Attempt to make style in tests too

* Added jax & jaxlib to the flax optional dependencies.

* Attempt to fix flake8 warnings ...

* Redo black again and again

* When black and flake8 fight each other for a space ... 💥 💥 💥

* Try removing trailing comma to make both black and flake happy!

* Fix invalid is_<framework>_available call, thanks @LysandreJik 🎉

* Fix another invalid import in flax_roberta test

* Bump and pin flax release to 0.1.0.

* Make flake8 happy, remove unused jax import

* Change the type of the catch for msgpack.

* Remove unused import.

* Put seed as optional constructor parameter.

* trigger ci again

* Fix too much parameters in BertAttention.

* Formatting.

* Simplify Flax unittests to avoid machine crashes.

* Fix invalid number of arguments when raising issue for an unknown model.

* Address @bastings comment in PR, moving jax.jit decorated outside of __call__

* Fix incorrect path to require_flax/require_pytorch functions.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Attempt to make style.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correct rebasing of circle-ci dependencies

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix import sorting.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix unused imports.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Again import sorting...

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Installing missing nlp dependency for flax unittests.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Fix laoding of model for Flax implementations.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* jit the inner function call to make JAX-compatible

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Format !

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Flake one more time 🎶

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Rewrites BERT in Flax to the new Linen API (#7211)

* Rewrite Flax HuggingFace PR to Linen

* Some fixes

* Fix tests

* Fix CI with change of name of nlp (#7054)

* nlp -> datasets

* More nlp -> datasets

* Woopsie

* More nlp -> datasets

* One last

* Expose `is_flax_available` in file_utils.

* Added run_tests_flax to the CI.

* Attempt to make style

* trigger ci again

* Fix import sorting.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Revert "Rewrites BERT in Flax to the new Linen API (#7211)"

This reverts commit 23703a5eb3364e26a1cbc3ee34b4710d86a674b0.

* Remove jnp.lax references

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Reintroduce Linen changes ...

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use jax native's gelu function.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Renaming BertModel to BertModule to highlight the fact this is the Flax Module object.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Rewrite FlaxAutoModel test to not rely on pretrained_model_archive_map

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove unused variable in BertModule.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove unused variable in BertModule again

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Attempt to have is_flax_available working again.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Introduce JAX TensorType

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Improve ImportError message when trying to convert to various TensorType format.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Makes Flax model jittable.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Ensure flax models are jittable in unittests.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Remove unused imports.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Ensure jax imports are guarded behind is_flax_available.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style again

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style again again

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style again again again

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Update src/transformers/file_utils.py

Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

* Bump flax to it's latest version

Co-authored-by: Marc van Zee <marcvanzee@gmail.com>

* Bump jax version to at least 0.2.0

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Style.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Update the unittest to use TensorType.JAX

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* isort import in tests.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Match new flax parameters name "params"

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove unused imports.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Add flax models to transformers __init__

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Attempt to address all CI related comments.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correct circle.yml indent.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correct circle.yml indent (2)

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Remove coverage from flax tests

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Addressing many naming suggestions from comments

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Simplify for loop logic to interate over layers in FlaxBertLayerCollection

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* use f-string syntax for formatting logs.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use config property from FlaxPreTrainedModel.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* use "cls_token" instead of "first_token" variable name.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* use "hidden_state" instead of "h" variable name.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correct class reference in docstring to link to Flax related modules.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added HF + Google Flax team copyright.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make Roberta independent from Bert

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Move activation functions to flax_utils.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Move activation functions to flax_utils for bert.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Added docstring for BERT

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Update import for Bert and Roberta tokenizers

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make style.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* fix-copies

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Correct FlaxRobertaLayer to match PyTorch.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use the same store_artifact for flax unittest

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Style.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Make sure gradient are disabled only locally for flax unittest using torch equivalence.

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

* Use relative imports

Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>

Co-authored-by: Stefan Schweter <stefan@schweter.it>
Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Funtowicz Morgan
2020-10-19 15:55:41 +02:00
committed by GitHub
parent 0193c8290d
commit 8f8f8d99fc
13 changed files with 1491 additions and 13 deletions

View File

@@ -0,0 +1,194 @@
# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from pickle import UnpicklingError
from typing import Dict
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.serialization import to_bytes
from flax.traverse_util import unflatten_dict
from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .file_utils import WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .utils import logging
logger = logging.get_logger(__name__)
@jax.jit
def gelu(x):
r"""Gaussian error linear unit activation function.
Computes the element-wise function:
.. math::
\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
\sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
We explicitly use the approximation rather than the exact formulation for
speed. For more information, see `Gaussian Error Linear Units (GELUs)
<https://arxiv.org/abs/1606.08415>`_, section 2.
"""
return x * 0.5 * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0)))
ACT2FN = {
"gelu": nn.gelu,
"relu": nn.relu,
"swish": nn.swish,
"gelu_new": gelu,
}
class FlaxPreTrainedModel(ABC):
config_class = None
pretrained_model_archive_map = {}
base_model_prefix = ""
model_class = None
def __init__(self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0):
if config is None:
raise ValueError("config cannot be None")
if module is None:
raise ValueError("module cannot be None")
if params is None:
raise ValueError("state cannot be None")
# Those are private to be exposed as typed property on derived classes.
self._config = config
self._module = module
# Those are public as their type is generic to every derived classes.
self.key = PRNGKey(seed)
self.params = params
self.model = module
@property
def config(self) -> PretrainedConfig:
return self._config
@staticmethod
@abstractmethod
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
raise NotImplementedError()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r"""
Instantiate a pretrained Flax model from a pre-trained model configuration.
"""
config = kwargs.pop("config", None)
# state_dict = kwargs.pop("state_dict", None)
cache_dir = kwargs.pop("cache_dir", None)
# from_tf = kwargs.pop("from_tf", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
# output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True)
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path,
*model_args,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
**kwargs,
)
else:
model_kwargs = kwargs
# Load model
if pretrained_model_name_or_path is not None:
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
else:
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, use_cdn=use_cdn)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(
archive_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
else:
msg = (
f"Model name '{pretrained_model_name_or_path}' "
f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
f"We assumed '{archive_file}' was a path or url to model weight files but "
"couldn't find any such file at this path or url."
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file:
logger.info(f"loading weights file {archive_file}")
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
# Instantiate model.
with open(resolved_archive_file, "rb") as state_f:
try:
from flax.serialization import from_bytes
state = from_bytes(cls.model_class, state_f)
except TypeError:
try:
import torch
state = torch.load(state_f)
state = {k: v.numpy() for k, v in state.items()}
state = cls.convert_from_pytorch(state, config)
state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
except UnpicklingError:
raise EnvironmentError(
f"Unable to convert model {archive_file} to Flax deserializable object. "
"Supported format are PyTorch archive or Flax msgpack"
)
return cls(config, state, *model_args, **model_kwargs)
def save_pretrained(self, folder):
folder_abs = os.path.abspath(folder)
if not os.path.exists(folder_abs):
os.mkdir(folder_abs)
with open(os.path.join(folder_abs, f"{self._config.model_type}.flax", "wb")) as f:
model_bytes = to_bytes(self.params)
f.write(model_bytes)