[Flax] Big FlaxBert Refactor (#11364)
* improve flax * refactor * typos * Update src/transformers/modeling_flax_utils.py * Apply suggestions from code review * Update src/transformers/modeling_flax_utils.py * fix typo * improve error tolerance * typo * correct nasty saving bug * fix from pretrained * correct tree map * add note * correct weight tying
This commit is contained in:
committed by
GitHub
parent
3ed5e97ba0
commit
8c9b5fcbaf
@@ -12,12 +12,17 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch - TF 2.0 general utilities."""
|
""" PyTorch - Flax general utilities."""
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from pickle import UnpicklingError
|
||||||
|
|
||||||
from flax.core.frozen_dict import unfreeze
|
import numpy as np
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import transformers
|
||||||
|
from flax.serialization import from_bytes
|
||||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
|
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
@@ -37,7 +42,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
|
|||||||
import torch # noqa: F401
|
import torch # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see "
|
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see "
|
||||||
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
@@ -57,7 +62,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||||||
# convert pytorch tensor to numpy
|
# convert pytorch tensor to numpy
|
||||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||||
|
|
||||||
random_flax_state_dict = flatten_dict(unfreeze(flax_model.params))
|
random_flax_state_dict = flatten_dict(flax_model.params)
|
||||||
flax_state_dict = {}
|
flax_state_dict = {}
|
||||||
|
|
||||||
remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
|
remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
|
||||||
@@ -80,7 +85,12 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||||||
elif add_base_model_prefix and require_base_model_prefix:
|
elif add_base_model_prefix and require_base_model_prefix:
|
||||||
pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
|
pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
|
||||||
|
|
||||||
if pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
|
# Correctly rename weight parameters
|
||||||
|
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
|
||||||
|
pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
||||||
|
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
|
||||||
|
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
||||||
|
elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
|
||||||
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
||||||
pt_tensor = pt_tensor.T
|
pt_tensor = pt_tensor.T
|
||||||
elif pt_tuple_key[-1] == "gamma":
|
elif pt_tuple_key[-1] == "gamma":
|
||||||
@@ -89,12 +99,128 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||||||
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
||||||
|
|
||||||
if pt_tuple_key in random_flax_state_dict:
|
if pt_tuple_key in random_flax_state_dict:
|
||||||
if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape:
|
if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
|
"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# add unexpected weight so that warning is thrown
|
# also add unexpected weight so that warning is thrown
|
||||||
flax_state_dict[pt_tuple_key] = pt_tensor
|
flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
|
||||||
|
|
||||||
return unflatten_dict(flax_state_dict)
|
return unflatten_dict(flax_state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
#####################
|
||||||
|
# Flax => PyTorch #
|
||||||
|
#####################
|
||||||
|
|
||||||
|
|
||||||
|
def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
|
||||||
|
"""Load flax checkpoints in a PyTorch model"""
|
||||||
|
flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
|
||||||
|
logger.info(f"Loading Flax weights from {flax_checkpoint_path}")
|
||||||
|
|
||||||
|
# import correct flax class
|
||||||
|
flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)
|
||||||
|
|
||||||
|
# load flax weight dict
|
||||||
|
with open(flax_checkpoint_path, "rb") as state_f:
|
||||||
|
try:
|
||||||
|
flax_state_dict = from_bytes(flax_cls, state_f.read())
|
||||||
|
except UnpicklingError:
|
||||||
|
raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
|
||||||
|
|
||||||
|
return load_flax_weights_in_pytorch_model(model, flax_state_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
||||||
|
"""Load flax checkpoints in a PyTorch model"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
logger.error(
|
||||||
|
"Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see "
|
||||||
|
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
flax_state_dict = flatten_dict(flax_state)
|
||||||
|
pt_model_dict = pt_model.state_dict()
|
||||||
|
|
||||||
|
remove_base_model_prefix = (pt_model.base_model_prefix in flax_state) and (
|
||||||
|
pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()])
|
||||||
|
)
|
||||||
|
add_base_model_prefix = (pt_model.base_model_prefix not in flax_state) and (
|
||||||
|
pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()])
|
||||||
|
)
|
||||||
|
|
||||||
|
# keep track of unexpected & missing keys
|
||||||
|
unexpected_keys = []
|
||||||
|
missing_keys = set(pt_model_dict.keys())
|
||||||
|
|
||||||
|
for flax_key_tuple, flax_tensor in flax_state_dict.items():
|
||||||
|
has_base_model_prefix = flax_key_tuple[0] == pt_model.base_model_prefix
|
||||||
|
require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict
|
||||||
|
|
||||||
|
# adapt flax_key to prepare for loading from/to base model only
|
||||||
|
if remove_base_model_prefix and has_base_model_prefix:
|
||||||
|
flax_key_tuple = flax_key_tuple[1:]
|
||||||
|
elif add_base_model_prefix and require_base_model_prefix:
|
||||||
|
flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple
|
||||||
|
|
||||||
|
# rename flax weights to PyTorch format
|
||||||
|
if flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict:
|
||||||
|
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
|
||||||
|
flax_tensor = flax_tensor.T
|
||||||
|
elif flax_key_tuple[-1] in ["scale", "embedding"]:
|
||||||
|
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
|
||||||
|
|
||||||
|
flax_key = ".".join(flax_key_tuple)
|
||||||
|
|
||||||
|
if flax_key in pt_model_dict:
|
||||||
|
if flax_tensor.shape != pt_model_dict[flax_key].shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected"
|
||||||
|
f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# add weight to pytorch dict
|
||||||
|
flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
|
||||||
|
pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
|
||||||
|
# remove from missing keys
|
||||||
|
missing_keys.remove(flax_key)
|
||||||
|
else:
|
||||||
|
# weight is not expected by PyTorch model
|
||||||
|
unexpected_keys.append(flax_key)
|
||||||
|
|
||||||
|
pt_model.load_state_dict(pt_model_dict)
|
||||||
|
|
||||||
|
# re-transform missing_keys to list
|
||||||
|
missing_keys = list(missing_keys)
|
||||||
|
|
||||||
|
if len(unexpected_keys) > 0:
|
||||||
|
logger.warning(
|
||||||
|
"Some weights of the Flax model were not used when "
|
||||||
|
f"initializing the PyTorch model {pt_model.__class__.__name__}: {unexpected_keys}\n"
|
||||||
|
f"- This IS expected if you are initializing {pt_model.__class__.__name__} from a Flax model trained on another task "
|
||||||
|
"or with another architecture (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n"
|
||||||
|
f"- This IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect "
|
||||||
|
"to be exactly identical (e.g. initializing a BertForSequenceClassification model from a FlaxBertForSequenceClassification model)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n")
|
||||||
|
if len(missing_keys) > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model "
|
||||||
|
f"and are newly initialized: {missing_keys}\n"
|
||||||
|
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n"
|
||||||
|
"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||||
|
f"you can already use {pt_model.__class__.__name__} for predictions without further training."
|
||||||
|
)
|
||||||
|
|
||||||
|
return pt_model
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -22,7 +22,7 @@ from typing import Dict, Set, Tuple, Union
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||||
from flax.serialization import from_bytes, to_bytes
|
from flax.serialization import from_bytes, to_bytes
|
||||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
@@ -46,7 +46,7 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
ACT2FN = {
|
ACT2FN = {
|
||||||
"gelu": nn.gelu,
|
"gelu": partial(nn.gelu, approximate=False),
|
||||||
"relu": nn.relu,
|
"relu": nn.relu,
|
||||||
"silu": nn.swish,
|
"silu": nn.swish,
|
||||||
"swish": nn.swish,
|
"swish": nn.swish,
|
||||||
@@ -129,7 +129,7 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
"Some parameters are missing. Make sure that `params` include the following "
|
"Some parameters are missing. Make sure that `params` include the following "
|
||||||
f"parameters {self.required_params - param_keys}"
|
f"parameters {self.required_params - param_keys}"
|
||||||
)
|
)
|
||||||
self._params = freeze(params)
|
self._params = params
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
@@ -330,6 +330,10 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
state = from_bytes(cls, state_f.read())
|
state = from_bytes(cls, state_f.read())
|
||||||
except UnpicklingError:
|
except UnpicklingError:
|
||||||
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
|
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
|
||||||
|
# make sure all arrays are stored as jnp.arrays
|
||||||
|
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
|
||||||
|
# https://github.com/google/flax/issues/1261
|
||||||
|
state = jax.tree_util.tree_map(jnp.array, state)
|
||||||
|
|
||||||
# if model is base model only use model_prefix key
|
# if model is base model only use model_prefix key
|
||||||
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
||||||
@@ -337,6 +341,7 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
|
|
||||||
# flatten dicts
|
# flatten dicts
|
||||||
state = flatten_dict(state)
|
state = flatten_dict(state)
|
||||||
|
|
||||||
random_state = flatten_dict(unfreeze(model.params))
|
random_state = flatten_dict(unfreeze(model.params))
|
||||||
|
|
||||||
missing_keys = model.required_params - set(state.keys())
|
missing_keys = model.required_params - set(state.keys())
|
||||||
@@ -377,6 +382,7 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
|
|
||||||
# set correct parameters
|
# set correct parameters
|
||||||
model.params = unflatten_dict(state)
|
model.params = unflatten_dict(state)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from .activations import get_activation
|
|||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
|
FLAX_WEIGHTS_NAME,
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
TF_WEIGHTS_NAME,
|
TF_WEIGHTS_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
@@ -875,6 +876,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
|
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
|
||||||
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
|
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
|
||||||
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
||||||
|
- A path or url to a model folder containing a `flax checkpoint file` in `.msgpack` format (e.g,
|
||||||
|
``./flax_model/`` containing ``flax_model.msgpack``). In this case, ``from_flax`` should be set
|
||||||
|
to :obj:`True`.
|
||||||
- :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
|
- :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
|
||||||
arguments ``config`` and ``state_dict``).
|
arguments ``config`` and ``state_dict``).
|
||||||
model_args (sequence of positional arguments, `optional`):
|
model_args (sequence of positional arguments, `optional`):
|
||||||
@@ -907,6 +911,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Load the model weights from a TensorFlow checkpoint save file (see docstring of
|
Load the model weights from a TensorFlow checkpoint save file (see docstring of
|
||||||
``pretrained_model_name_or_path`` argument).
|
``pretrained_model_name_or_path`` argument).
|
||||||
|
from_flax (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Load the model weights from a Flax checkpoint save file (see docstring of
|
||||||
|
``pretrained_model_name_or_path`` argument).
|
||||||
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
cached versions if they exist.
|
cached versions if they exist.
|
||||||
@@ -968,11 +975,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
|
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
|
||||||
>>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
|
>>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
|
||||||
>>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
>>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
|
||||||
|
>>> model = BertModel.from_pretrained('bert-base-uncased', from_flax=True)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config = kwargs.pop("config", None)
|
config = kwargs.pop("config", None)
|
||||||
state_dict = kwargs.pop("state_dict", None)
|
state_dict = kwargs.pop("state_dict", None)
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
from_tf = kwargs.pop("from_tf", False)
|
from_tf = kwargs.pop("from_tf", False)
|
||||||
|
from_flax = kwargs.pop("from_flax", False)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
@@ -1023,13 +1034,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||||
# Load from a TF 2.0 checkpoint in priority if from_tf
|
# Load from a TF 2.0 checkpoint in priority if from_tf
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||||
|
elif from_flax and os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||||
|
# Load from a Flax checkpoint in priority if from_flax
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||||
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
# Load from a PyTorch checkpoint
|
# Load from a PyTorch checkpoint
|
||||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(
|
raise EnvironmentError(
|
||||||
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index']} found in "
|
f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index', FLAX_WEIGHTS_NAME]} found in "
|
||||||
f"directory {pretrained_model_name_or_path} or `from_tf` set to False."
|
f"directory {pretrained_model_name_or_path} or `from_tf` and `from_flax` set to False."
|
||||||
)
|
)
|
||||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
@@ -1041,9 +1055,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
)
|
)
|
||||||
archive_file = pretrained_model_name_or_path + ".index"
|
archive_file = pretrained_model_name_or_path + ".index"
|
||||||
else:
|
else:
|
||||||
|
# set correct filename
|
||||||
|
if from_tf:
|
||||||
|
filename = TF2_WEIGHTS_NAME
|
||||||
|
elif from_flax:
|
||||||
|
filename = FLAX_WEIGHTS_NAME
|
||||||
|
else:
|
||||||
|
filename = WEIGHTS_NAME
|
||||||
|
|
||||||
archive_file = hf_bucket_url(
|
archive_file = hf_bucket_url(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
|
filename=filename,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
mirror=mirror,
|
mirror=mirror,
|
||||||
)
|
)
|
||||||
@@ -1090,7 +1112,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
else:
|
else:
|
||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|
||||||
if state_dict is None and not from_tf:
|
if state_dict is None and not (from_tf or from_flax):
|
||||||
try:
|
try:
|
||||||
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -1120,6 +1142,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
|||||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
elif from_flax:
|
||||||
|
try:
|
||||||
|
from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model
|
||||||
|
|
||||||
|
model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
|
||||||
|
except ImportError:
|
||||||
|
logger.error(
|
||||||
|
"Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see "
|
||||||
|
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
||||||
|
)
|
||||||
|
raise
|
||||||
else:
|
else:
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
# Convert old format to new format if needed from a PyTorch state_dict
|
||||||
old_keys = []
|
old_keys = []
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -95,68 +95,6 @@ BERT_INPUTS_DOCSTRING = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertLayerNorm(nn.Module):
|
|
||||||
"""
|
|
||||||
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
hidden_size: int
|
|
||||||
epsilon: float = 1e-6
|
|
||||||
dtype: jnp.dtype = jnp.float32
|
|
||||||
use_bias: bool = True
|
|
||||||
scale: bool = True
|
|
||||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
|
||||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
|
|
||||||
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
"""
|
|
||||||
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
|
|
||||||
a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that
|
|
||||||
maintains the mean activation within each example close to 0 and the activation standard deviation close to 1
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: the inputs
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Normalized inputs (the same shape as inputs).
|
|
||||||
"""
|
|
||||||
mean = jnp.mean(x, axis=-1, keepdims=True)
|
|
||||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
|
||||||
var = mean2 - jax.lax.square(mean)
|
|
||||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
|
||||||
|
|
||||||
if self.scale:
|
|
||||||
mul = mul * jnp.asarray(self.weight)
|
|
||||||
y = (x - mean) * mul
|
|
||||||
|
|
||||||
if self.use_bias:
|
|
||||||
y = y + jnp.asarray(self.bias)
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertEmbedding(nn.Module):
|
|
||||||
"""
|
|
||||||
Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
|
|
||||||
use 'weight'
|
|
||||||
"""
|
|
||||||
|
|
||||||
vocab_size: int
|
|
||||||
hidden_size: int
|
|
||||||
initializer_range: float
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
|
|
||||||
self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
|
|
||||||
|
|
||||||
def __call__(self, input_ids):
|
|
||||||
return jnp.take(self.embeddings, input_ids, axis=0)
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertEmbeddings(nn.Module):
|
class FlaxBertEmbeddings(nn.Module):
|
||||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||||
|
|
||||||
@@ -164,35 +102,37 @@ class FlaxBertEmbeddings(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.word_embeddings = FlaxBertEmbedding(
|
self.word_embeddings = nn.Embed(
|
||||||
self.config.vocab_size,
|
self.config.vocab_size,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.position_embeddings = FlaxBertEmbedding(
|
self.position_embeddings = nn.Embed(
|
||||||
self.config.max_position_embeddings,
|
self.config.max_position_embeddings,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.token_type_embeddings = FlaxBertEmbedding(
|
self.token_type_embeddings = nn.Embed(
|
||||||
self.config.type_vocab_size,
|
self.config.type_vocab_size,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
|
||||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||||
|
batch_size, sequence_length = input_ids.shape
|
||||||
# Embed
|
# Embed
|
||||||
inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
|
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
|
||||||
position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
|
position_embeds = self.position_embeddings(position_ids.astype("i4"))
|
||||||
token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
|
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
|
||||||
|
|
||||||
# Sum all embeddings
|
# Sum all embeddings
|
||||||
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
|
||||||
|
# hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))
|
||||||
|
|
||||||
# Layer Norm
|
# Layer Norm
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -281,7 +221,7 @@ class FlaxBertSelfOutput(nn.Module):
|
|||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size)
|
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
|
||||||
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
||||||
@@ -337,7 +277,7 @@ class FlaxBertOutput(nn.Module):
|
|||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
@@ -372,7 +312,7 @@ class FlaxBertLayerCollection(nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
for layer in self.layers:
|
for i, layer in enumerate(self.layers):
|
||||||
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -412,7 +352,7 @@ class FlaxBertPredictionHeadTransform(nn.Module):
|
|||||||
def setup(self):
|
def setup(self):
|
||||||
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
|
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
|
||||||
self.activation = ACT2FN[self.config.hidden_act]
|
self.activation = ACT2FN[self.config.hidden_act]
|
||||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
def __call__(self, hidden_states):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
@@ -423,14 +363,22 @@ class FlaxBertPredictionHeadTransform(nn.Module):
|
|||||||
class FlaxBertLMPredictionHead(nn.Module):
|
class FlaxBertLMPredictionHead(nn.Module):
|
||||||
config: BertConfig
|
config: BertConfig
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
|
self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
|
||||||
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype)
|
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
|
||||||
|
self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
def __call__(self, hidden_states, shared_embedding=None):
|
||||||
hidden_states = self.transform(hidden_states)
|
hidden_states = self.transform(hidden_states)
|
||||||
hidden_states = self.decoder(hidden_states)
|
|
||||||
|
if shared_embedding is not None:
|
||||||
|
hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states = self.decoder(hidden_states)
|
||||||
|
|
||||||
|
hidden_states += self.bias
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -441,8 +389,8 @@ class FlaxBertOnlyMLMHead(nn.Module):
|
|||||||
def setup(self):
|
def setup(self):
|
||||||
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
|
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
def __call__(self, hidden_states, shared_embedding=None):
|
||||||
hidden_states = self.predictions(hidden_states)
|
hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -464,8 +412,8 @@ class FlaxBertPreTrainingHeads(nn.Module):
|
|||||||
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
|
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
|
||||||
self.seq_relationship = nn.Dense(2, dtype=self.dtype)
|
self.seq_relationship = nn.Dense(2, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, pooled_output):
|
def __call__(self, hidden_states, pooled_output, shared_embedding=None):
|
||||||
prediction_scores = self.predictions(hidden_states)
|
prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
|
||||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||||
return prediction_scores, seq_relationship_score
|
return prediction_scores, seq_relationship_score
|
||||||
|
|
||||||
@@ -490,7 +438,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
token_type_ids = jnp.ones_like(input_ids)
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
@@ -514,7 +462,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
token_type_ids = jnp.ones_like(input_ids)
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
@@ -546,7 +494,6 @@ class FlaxBertModule(nn.Module):
|
|||||||
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
|
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||||
|
|
||||||
hidden_states = self.embeddings(
|
hidden_states = self.embeddings(
|
||||||
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||||
)
|
)
|
||||||
@@ -582,7 +529,15 @@ class FlaxBertForPreTrainingModule(nn.Module):
|
|||||||
hidden_states, pooled_output = self.bert(
|
hidden_states, pooled_output = self.bert(
|
||||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||||
)
|
)
|
||||||
prediction_scores, seq_relationship_score = self.cls(hidden_states, pooled_output)
|
|
||||||
|
if self.config.tie_word_embeddings:
|
||||||
|
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||||
|
else:
|
||||||
|
shared_embedding = None
|
||||||
|
|
||||||
|
prediction_scores, seq_relationship_score = self.cls(
|
||||||
|
hidden_states, pooled_output, shared_embedding=shared_embedding
|
||||||
|
)
|
||||||
|
|
||||||
return (prediction_scores, seq_relationship_score)
|
return (prediction_scores, seq_relationship_score)
|
||||||
|
|
||||||
@@ -612,8 +567,13 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
|||||||
# Model
|
# Model
|
||||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||||
|
|
||||||
|
if self.config.tie_word_embeddings:
|
||||||
|
shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
|
||||||
|
else:
|
||||||
|
shared_embedding = None
|
||||||
|
|
||||||
# Compute the prediction scores
|
# Compute the prediction scores
|
||||||
logits = self.cls(hidden_states)
|
logits = self.cls(hidden_states, shared_embedding=shared_embedding)
|
||||||
|
|
||||||
return (logits,)
|
return (logits,)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -12,9 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Callable, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
@@ -110,70 +108,6 @@ ROBERTA_INPUTS_DOCSTRING = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta
|
|
||||||
class FlaxRobertaLayerNorm(nn.Module):
|
|
||||||
"""
|
|
||||||
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
hidden_size: int
|
|
||||||
epsilon: float = 1e-6
|
|
||||||
dtype: jnp.dtype = jnp.float32
|
|
||||||
use_bias: bool = True
|
|
||||||
scale: bool = True
|
|
||||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
|
||||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
|
|
||||||
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
|
|
||||||
|
|
||||||
def __call__(self, x):
|
|
||||||
"""
|
|
||||||
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
|
|
||||||
a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that
|
|
||||||
maintains the mean activation within each example close to 0 and the activation standard deviation close to 1
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: the inputs
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Normalized inputs (the same shape as inputs).
|
|
||||||
"""
|
|
||||||
mean = jnp.mean(x, axis=-1, keepdims=True)
|
|
||||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
|
||||||
var = mean2 - jax.lax.square(mean)
|
|
||||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
|
||||||
|
|
||||||
if self.scale:
|
|
||||||
mul = mul * jnp.asarray(self.weight)
|
|
||||||
y = (x - mean) * mul
|
|
||||||
|
|
||||||
if self.use_bias:
|
|
||||||
y = y + jnp.asarray(self.bias)
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta
|
|
||||||
class FlaxRobertaEmbedding(nn.Module):
|
|
||||||
"""
|
|
||||||
Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
|
|
||||||
use 'weight'
|
|
||||||
"""
|
|
||||||
|
|
||||||
vocab_size: int
|
|
||||||
hidden_size: int
|
|
||||||
initializer_range: float
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
||||||
|
|
||||||
def setup(self):
|
|
||||||
init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range)
|
|
||||||
self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size))
|
|
||||||
|
|
||||||
def __call__(self, input_ids):
|
|
||||||
return jnp.take(self.embeddings, input_ids, axis=0)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
|
||||||
class FlaxRobertaEmbeddings(nn.Module):
|
class FlaxRobertaEmbeddings(nn.Module):
|
||||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||||
@@ -182,35 +116,37 @@ class FlaxRobertaEmbeddings(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.word_embeddings = FlaxRobertaEmbedding(
|
self.word_embeddings = nn.Embed(
|
||||||
self.config.vocab_size,
|
self.config.vocab_size,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.position_embeddings = FlaxRobertaEmbedding(
|
self.position_embeddings = nn.Embed(
|
||||||
self.config.max_position_embeddings,
|
self.config.max_position_embeddings,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.token_type_embeddings = FlaxRobertaEmbedding(
|
self.token_type_embeddings = nn.Embed(
|
||||||
self.config.type_vocab_size,
|
self.config.type_vocab_size,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
|
||||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||||
|
batch_size, sequence_length = input_ids.shape
|
||||||
# Embed
|
# Embed
|
||||||
inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
|
inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
|
||||||
position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
|
position_embeds = self.position_embeddings(position_ids.astype("i4"))
|
||||||
token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))
|
token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
|
||||||
|
|
||||||
# Sum all embeddings
|
# Sum all embeddings
|
||||||
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
hidden_states = inputs_embeds + token_type_embeddings + position_embeds
|
||||||
|
# hidden_states = hidden_states.reshape((batch_size, sequence_length, -1))
|
||||||
|
|
||||||
# Layer Norm
|
# Layer Norm
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
@@ -301,7 +237,7 @@ class FlaxRobertaSelfOutput(nn.Module):
|
|||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size)
|
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
|
||||||
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
||||||
@@ -360,7 +296,7 @@ class FlaxRobertaOutput(nn.Module):
|
|||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
@@ -397,7 +333,7 @@ class FlaxRobertaLayerCollection(nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
for layer in self.layers:
|
for i, layer in enumerate(self.layers):
|
||||||
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -515,7 +451,6 @@ class FlaxRobertaModule(nn.Module):
|
|||||||
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
|
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||||
|
|
||||||
hidden_states = self.embeddings(
|
hidden_states = self.embeddings(
|
||||||
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,7 +28,10 @@ if is_flax_available():
|
|||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
|
convert_pytorch_state_dict_to_flax,
|
||||||
|
load_flax_weights_in_pytorch_model,
|
||||||
|
)
|
||||||
|
|
||||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||||
|
|
||||||
@@ -83,29 +86,32 @@ class FlaxModelTesterMixin:
|
|||||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||||
|
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
def test_equivalence_flax_pytorch(self):
|
def test_equivalence_pt_to_flax(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
with self.subTest(model_class.__name__):
|
with self.subTest(model_class.__name__):
|
||||||
|
# prepare inputs
|
||||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||||
|
|
||||||
|
# load corresponding PyTorch class
|
||||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||||
pt_model = pt_model_class(config).eval()
|
|
||||||
|
|
||||||
|
pt_model = pt_model_class(config).eval()
|
||||||
fx_model = model_class(config, dtype=jnp.float32)
|
fx_model = model_class(config, dtype=jnp.float32)
|
||||||
|
|
||||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||||
fx_model.params = fx_state
|
fx_model.params = fx_state
|
||||||
|
|
||||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
@@ -116,7 +122,50 @@ class FlaxModelTesterMixin:
|
|||||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||||
)
|
)
|
||||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_equivalence_flax_to_pt(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
with self.subTest(model_class.__name__):
|
||||||
|
# prepare inputs
|
||||||
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||||
|
|
||||||
|
# load corresponding PyTorch class
|
||||||
|
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||||
|
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||||
|
|
||||||
|
pt_model = pt_model_class(config).eval()
|
||||||
|
fx_model = model_class(config, dtype=jnp.float32)
|
||||||
|
|
||||||
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||||
|
|
||||||
|
# make sure weights are tied in PyTorch
|
||||||
|
pt_model.tie_weights()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
|
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||||
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
fx_model.save_pretrained(tmpdirname)
|
||||||
|
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||||
|
)
|
||||||
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
||||||
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
|
||||||
|
|
||||||
def test_from_pretrained_save_pretrained(self):
|
def test_from_pretrained_save_pretrained(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -134,7 +183,7 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
outputs_loaded = model_loaded(**prepared_inputs_dict)
|
outputs_loaded = model_loaded(**prepared_inputs_dict)
|
||||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||||
|
|
||||||
def test_jit_compilation(self):
|
def test_jit_compilation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user