[WIP][Flax] Add general conversion script (#10809)
* save intermediate * finish first version * delete some more * improve import * fix roberta * Update src/transformers/modeling_flax_pytorch_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_flax_pytorch_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * small corrections * apply all comments * fix deterministic * make fix-copies Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
604c085087
commit
8780caa388
100
src/transformers/modeling_flax_pytorch_utils.py
Normal file
100
src/transformers/modeling_flax_pytorch_utils.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" PyTorch - TF 2.0 general utilities."""
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from flax.core.frozen_dict import unfreeze
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
#####################
|
||||||
|
# PyTorch => Flax #
|
||||||
|
#####################
|
||||||
|
|
||||||
|
|
||||||
|
def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
|
||||||
|
"""Load pytorch checkpoints in a flax model"""
|
||||||
|
try:
|
||||||
|
import torch # noqa: F401
|
||||||
|
except ImportError:
|
||||||
|
logger.error(
|
||||||
|
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see "
|
||||||
|
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
||||||
|
logger.info("Loading PyTorch weights from {}".format(pt_path))
|
||||||
|
|
||||||
|
pt_state_dict = torch.load(pt_path, map_location="cpu")
|
||||||
|
logger.info("PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values())} parameters.")
|
||||||
|
|
||||||
|
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
|
||||||
|
|
||||||
|
return flax_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||||
|
# convert pytorch tensor to numpy
|
||||||
|
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||||
|
|
||||||
|
random_flax_state_dict = flatten_dict(unfreeze(flax_model.params))
|
||||||
|
flax_state_dict = {}
|
||||||
|
|
||||||
|
remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
|
||||||
|
flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
||||||
|
)
|
||||||
|
add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
|
||||||
|
flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
|
||||||
|
for pt_key, pt_tensor in pt_state_dict.items():
|
||||||
|
|
||||||
|
pt_tuple_key = tuple(pt_key.split("."))
|
||||||
|
|
||||||
|
has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
|
||||||
|
require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
|
||||||
|
|
||||||
|
if remove_base_model_prefix and has_base_model_prefix:
|
||||||
|
pt_tuple_key = pt_tuple_key[1:]
|
||||||
|
elif add_base_model_prefix and require_base_model_prefix:
|
||||||
|
pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
|
||||||
|
|
||||||
|
if pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
|
||||||
|
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
||||||
|
pt_tensor = pt_tensor.T
|
||||||
|
elif pt_tuple_key[-1] == "gamma":
|
||||||
|
pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
||||||
|
elif pt_tuple_key[-1] == "beta":
|
||||||
|
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
||||||
|
|
||||||
|
if pt_tuple_key in random_flax_state_dict:
|
||||||
|
if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape:
|
||||||
|
raise ValueError(
|
||||||
|
"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# add unexpected weight so that warning is thrown
|
||||||
|
flax_state_dict[pt_tuple_key] = pt_tensor
|
||||||
|
|
||||||
|
return unflatten_dict(flax_state_dict)
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pickle import UnpicklingError
|
from pickle import UnpicklingError
|
||||||
from typing import Dict, Set, Tuple, Union
|
from typing import Dict, Set, Tuple, Union
|
||||||
@@ -29,6 +29,7 @@ from jax.random import PRNGKey
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
|
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
|
||||||
|
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -121,11 +122,6 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
)
|
)
|
||||||
self._params = freeze(params)
|
self._params = freeze(params)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abstractmethod
|
|
||||||
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
@@ -307,25 +303,18 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
else:
|
else:
|
||||||
resolved_archive_file = None
|
resolved_archive_file = None
|
||||||
|
|
||||||
# Instantiate model.
|
|
||||||
with open(resolved_archive_file, "rb") as state_f:
|
|
||||||
try:
|
|
||||||
if from_pt:
|
|
||||||
import torch
|
|
||||||
|
|
||||||
state = torch.load(state_f)
|
|
||||||
|
|
||||||
state = convert_state_dict_from_pt(cls, state, config)
|
|
||||||
else:
|
|
||||||
state = from_bytes(cls, state_f.read())
|
|
||||||
except UnpicklingError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"Unable to convert pytorch model {archive_file} to Flax deserializable object. "
|
|
||||||
)
|
|
||||||
|
|
||||||
# init random models
|
# init random models
|
||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|
||||||
|
if from_pt:
|
||||||
|
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
|
||||||
|
else:
|
||||||
|
with open(resolved_archive_file, "rb") as state_f:
|
||||||
|
try:
|
||||||
|
state = from_bytes(cls, state_f.read())
|
||||||
|
except UnpicklingError:
|
||||||
|
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
|
||||||
|
|
||||||
# if model is base model only use model_prefix key
|
# if model is base model only use model_prefix key
|
||||||
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
||||||
state = state[cls.base_model_prefix]
|
state = state[cls.base_model_prefix]
|
||||||
@@ -341,6 +330,10 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
for missing_key in missing_keys:
|
for missing_key in missing_keys:
|
||||||
state[missing_key] = random_state[missing_key]
|
state[missing_key] = random_state[missing_key]
|
||||||
|
|
||||||
|
# remove unexpected keys to not be saved again
|
||||||
|
for unexpected_key in unexpected_keys:
|
||||||
|
del state[unexpected_key]
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
||||||
@@ -393,13 +386,3 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
|
with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
|
||||||
model_bytes = to_bytes(self.params)
|
model_bytes = to_bytes(self.params)
|
||||||
f.write(model_bytes)
|
f.write(model_bytes)
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig):
|
|
||||||
"""
|
|
||||||
Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict
|
|
||||||
"""
|
|
||||||
state = {k: v.numpy() for k, v in state.items()}
|
|
||||||
state = model_class.convert_from_pytorch(state, config)
|
|
||||||
state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
|
|
||||||
return state
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Callable, Dict, Tuple
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -21,6 +21,8 @@ import flax.linen as nn
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
|
from flax.linen import dot_product_attention
|
||||||
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
@@ -99,17 +101,15 @@ class FlaxBertLayerNorm(nn.Module):
|
|||||||
|
|
||||||
hidden_size: int
|
hidden_size: int
|
||||||
epsilon: float = 1e-6
|
epsilon: float = 1e-6
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32
|
||||||
bias: bool = True # If True, bias (beta) is added.
|
use_bias: bool = True
|
||||||
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
|
scale: bool = True
|
||||||
# (also e.g. nn.relu), this can be disabled since the scaling will be
|
|
||||||
# done by the next layer.
|
|
||||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
||||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
|
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
|
||||||
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
|
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
"""
|
"""
|
||||||
@@ -129,11 +129,11 @@ class FlaxBertLayerNorm(nn.Module):
|
|||||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||||
|
|
||||||
if self.scale:
|
if self.scale:
|
||||||
mul = mul * jnp.asarray(self.gamma)
|
mul = mul * jnp.asarray(self.weight)
|
||||||
y = (x - mean) * mul
|
y = (x - mean) * mul
|
||||||
|
|
||||||
if self.bias:
|
if self.use_bias:
|
||||||
y = y + jnp.asarray(self.beta)
|
y = y + jnp.asarray(self.bias)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
@@ -167,24 +167,21 @@ class FlaxBertEmbeddings(nn.Module):
|
|||||||
self.config.vocab_size,
|
self.config.vocab_size,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
initializer_range=self.config.initializer_range,
|
||||||
name="word_embeddings",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.position_embeddings = FlaxBertEmbedding(
|
self.position_embeddings = FlaxBertEmbedding(
|
||||||
self.config.max_position_embeddings,
|
self.config.max_position_embeddings,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
initializer_range=self.config.initializer_range,
|
||||||
name="position_embeddings",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.token_type_embeddings = FlaxBertEmbedding(
|
self.token_type_embeddings = FlaxBertEmbedding(
|
||||||
self.config.type_vocab_size,
|
self.config.type_vocab_size,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
initializer_range=self.config.initializer_range,
|
||||||
name="token_type_embeddings",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
|
||||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||||
@@ -197,35 +194,116 @@ class FlaxBertEmbeddings(nn.Module):
|
|||||||
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
||||||
|
|
||||||
# Layer Norm
|
# Layer Norm
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertSelfAttention(nn.Module):
|
||||||
|
config: BertConfig
|
||||||
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query = nn.Dense(
|
||||||
|
self.config.hidden_size,
|
||||||
|
dtype=self.dtype,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
)
|
||||||
|
self.key = nn.Dense(
|
||||||
|
self.config.hidden_size,
|
||||||
|
dtype=self.dtype,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
)
|
||||||
|
self.value = nn.Dense(
|
||||||
|
self.config.hidden_size,
|
||||||
|
dtype=self.dtype,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||||
|
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
|
||||||
|
query_states = self.query(hidden_states).reshape(
|
||||||
|
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||||
|
)
|
||||||
|
value_states = self.value(hidden_states).reshape(
|
||||||
|
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||||
|
)
|
||||||
|
key_states = self.key(hidden_states).reshape(
|
||||||
|
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert the boolean attention mask to an attention bias.
|
||||||
|
if attention_mask is not None:
|
||||||
|
# attention mask in the form of attention bias
|
||||||
|
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||||
|
attention_bias = lax.select(
|
||||||
|
attention_mask > 0,
|
||||||
|
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||||
|
jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attention_bias = None
|
||||||
|
|
||||||
|
dropout_rng = None
|
||||||
|
if not deterministic and self.dropout_rate > 0.0:
|
||||||
|
dropout_rng = self.make_rng("dropout")
|
||||||
|
|
||||||
|
attn_output = dot_product_attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
bias=attention_bias,
|
||||||
|
dropout_rng=dropout_rng,
|
||||||
|
dropout_rate=self.config.attention_probs_dropout_prob,
|
||||||
|
broadcast_dropout=True,
|
||||||
|
deterministic=deterministic,
|
||||||
|
dtype=self.dtype,
|
||||||
|
precision=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBertSelfOutput(nn.Module):
|
||||||
|
config: BertConfig
|
||||||
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
self.dense = nn.Dense(
|
||||||
|
self.config.hidden_size,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size)
|
||||||
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertAttention(nn.Module):
|
class FlaxBertAttention(nn.Module):
|
||||||
config: BertConfig
|
config: BertConfig
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.self_attention = nn.attention.SelfAttention(
|
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
|
||||||
num_heads=self.config.num_attention_heads,
|
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
|
||||||
qkv_features=self.config.hidden_size,
|
|
||||||
dropout_rate=self.config.attention_probs_dropout_prob,
|
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
|
||||||
bias_init=jax.nn.initializers.zeros,
|
|
||||||
name="self",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
|
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||||
|
|
||||||
hidden_states = self.layer_norm(self_attn_output + hidden_states)
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -237,7 +315,6 @@ class FlaxBertIntermediate(nn.Module):
|
|||||||
self.dense = nn.Dense(
|
self.dense = nn.Dense(
|
||||||
self.config.intermediate_size,
|
self.config.intermediate_size,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
name="dense",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.activation = ACT2FN[self.config.hidden_act]
|
self.activation = ACT2FN[self.config.hidden_act]
|
||||||
@@ -256,16 +333,15 @@ class FlaxBertOutput(nn.Module):
|
|||||||
self.dense = nn.Dense(
|
self.dense = nn.Dense(
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
name="dense",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
hidden_states = self.layer_norm(hidden_states + attention_output)
|
hidden_states = self.LayerNorm(hidden_states + attention_output)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -274,9 +350,9 @@ class FlaxBertLayer(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.attention = FlaxBertAttention(self.config, name="attention", dtype=self.dtype)
|
self.attention = FlaxBertAttention(self.config, dtype=self.dtype)
|
||||||
self.intermediate = FlaxBertIntermediate(self.config, name="intermediate", dtype=self.dtype)
|
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxBertOutput(self.config, name="output", dtype=self.dtype)
|
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
@@ -305,10 +381,10 @@ class FlaxBertEncoder(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.layers = FlaxBertLayerCollection(self.config, name="layer", dtype=self.dtype)
|
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
|
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertPooler(nn.Module):
|
class FlaxBertPooler(nn.Module):
|
||||||
@@ -319,7 +395,6 @@ class FlaxBertPooler(nn.Module):
|
|||||||
self.dense = nn.Dense(
|
self.dense = nn.Dense(
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
name="dense",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -334,14 +409,14 @@ class FlaxBertPredictionHeadTransform(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.dense = nn.Dense(self.config.hidden_size, name="dense", dtype=self.dtype)
|
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
|
||||||
self.activation = ACT2FN[self.config.hidden_act]
|
self.activation = ACT2FN[self.config.hidden_act]
|
||||||
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
def __call__(self, hidden_states):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.activation(hidden_states)
|
hidden_states = self.activation(hidden_states)
|
||||||
return self.layer_norm(hidden_states)
|
return self.LayerNorm(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertLMPredictionHead(nn.Module):
|
class FlaxBertLMPredictionHead(nn.Module):
|
||||||
@@ -349,14 +424,10 @@ class FlaxBertLMPredictionHead(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.transform = FlaxBertPredictionHeadTransform(self.config, name="transform", dtype=self.dtype)
|
self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
|
||||||
self.decoder = nn.Dense(self.config.vocab_size, name="decoder", dtype=self.dtype)
|
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
def __call__(self, hidden_states):
|
||||||
# TODO: The output weights are the same as the input embeddings, but there is
|
|
||||||
# an output-only bias for each token.
|
|
||||||
# Need a link between the two variables so that the bias is correctly
|
|
||||||
# resized with `resize_token_embeddings`
|
|
||||||
hidden_states = self.transform(hidden_states)
|
hidden_states = self.transform(hidden_states)
|
||||||
hidden_states = self.decoder(hidden_states)
|
hidden_states = self.decoder(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
@@ -367,10 +438,10 @@ class FlaxBertOnlyMLMHead(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.mlm_head = FlaxBertLMPredictionHead(self.config, name="predictions", dtype=self.dtype)
|
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states):
|
def __call__(self, hidden_states):
|
||||||
hidden_states = self.mlm_head(hidden_states)
|
hidden_states = self.predictions(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -405,85 +476,6 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
|
|
||||||
jax_state = dict(pt_state)
|
|
||||||
|
|
||||||
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
|
|
||||||
for key, tensor in pt_state.items():
|
|
||||||
# Key parts
|
|
||||||
key_parts = set(key.split("."))
|
|
||||||
|
|
||||||
# Every dense layer has "kernel" parameters instead of "weight"
|
|
||||||
if "dense.weight" in key:
|
|
||||||
del jax_state[key]
|
|
||||||
key = key.replace("weight", "kernel")
|
|
||||||
jax_state[key] = tensor
|
|
||||||
|
|
||||||
if "decoder.weight" in key:
|
|
||||||
del jax_state[key]
|
|
||||||
key = key.replace("weight", "kernel")
|
|
||||||
jax_state[key] = tensor.T
|
|
||||||
|
|
||||||
# SelfAttention needs also to replace "weight" by "kernel"
|
|
||||||
if {"query", "key", "value"} & key_parts:
|
|
||||||
|
|
||||||
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
|
||||||
if "bias" in key:
|
|
||||||
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
|
||||||
elif "weight":
|
|
||||||
del jax_state[key]
|
|
||||||
key = key.replace("weight", "kernel")
|
|
||||||
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
|
|
||||||
jax_state[key] = tensor
|
|
||||||
|
|
||||||
# SelfAttention output is not a separate layer, remove one nesting
|
|
||||||
if "attention.output.dense" in key:
|
|
||||||
del jax_state[key]
|
|
||||||
key = key.replace("attention.output.dense", "attention.self.out")
|
|
||||||
jax_state[key] = tensor
|
|
||||||
|
|
||||||
# SelfAttention output is not a separate layer, remove nesting on layer norm
|
|
||||||
if "attention.output.LayerNorm" in key:
|
|
||||||
del jax_state[key]
|
|
||||||
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
|
|
||||||
jax_state[key] = tensor
|
|
||||||
|
|
||||||
# There are some transposed parameters w.r.t their PyTorch counterpart
|
|
||||||
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key or "transform.dense.kernel" in key:
|
|
||||||
jax_state[key] = tensor.T
|
|
||||||
|
|
||||||
# Self Attention output projection needs to be transposed
|
|
||||||
if "out.kernel" in key:
|
|
||||||
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
|
|
||||||
1, 2, 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pooler needs to transpose its kernel
|
|
||||||
if "pooler.dense.kernel" in key:
|
|
||||||
jax_state[key] = tensor.T
|
|
||||||
|
|
||||||
# Hack to correctly load some pytorch models
|
|
||||||
if "predictions.bias" in key:
|
|
||||||
del jax_state[key]
|
|
||||||
jax_state[".".join(key.split(".")[:2]) + ".decoder.bias"] = tensor
|
|
||||||
|
|
||||||
# Handle LayerNorm conversion
|
|
||||||
if "LayerNorm" in key:
|
|
||||||
del jax_state[key]
|
|
||||||
|
|
||||||
# Replace LayerNorm by layer_norm
|
|
||||||
new_key = key.replace("LayerNorm", "layer_norm")
|
|
||||||
|
|
||||||
if "weight" in key:
|
|
||||||
new_key = new_key.replace("weight", "gamma")
|
|
||||||
elif "bias" in key:
|
|
||||||
new_key = new_key.replace("bias", "beta")
|
|
||||||
|
|
||||||
jax_state[new_key] = tensor
|
|
||||||
|
|
||||||
return jax_state
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
@@ -541,9 +533,9 @@ class FlaxBertModule(nn.Module):
|
|||||||
add_pooling_layer: bool = True
|
add_pooling_layer: bool = True
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.embeddings = FlaxBertEmbeddings(self.config, name="embeddings", dtype=self.dtype)
|
self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
|
||||||
self.encoder = FlaxBertEncoder(self.config, name="encoder", dtype=self.dtype)
|
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
|
||||||
self.pooler = FlaxBertPooler(self.config, name="pooler", dtype=self.dtype)
|
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||||
|
|
||||||
@@ -602,15 +594,13 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.encoder = FlaxBertModule(
|
self.bert = FlaxBertModule(
|
||||||
config=self.config,
|
config=self.config,
|
||||||
add_pooling_layer=False,
|
add_pooling_layer=False,
|
||||||
name="bert",
|
|
||||||
)
|
)
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
self.mlm_head = FlaxBertOnlyMLMHead(
|
self.cls = FlaxBertOnlyMLMHead(
|
||||||
config=self.config,
|
config=self.config,
|
||||||
name="cls",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -618,12 +608,10 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
|||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||||
):
|
):
|
||||||
# Model
|
# Model
|
||||||
hidden_states = self.encoder(
|
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute the prediction scores
|
# Compute the prediction scores
|
||||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
logits = self.mlm_head(hidden_states)
|
logits = self.cls(hidden_states)
|
||||||
|
|
||||||
return (logits,)
|
return (logits,)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Callable, Dict, Tuple
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -20,6 +20,8 @@ import flax.linen as nn
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
|
from flax.linen import dot_product_attention
|
||||||
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
@@ -116,17 +118,15 @@ class FlaxRobertaLayerNorm(nn.Module):
|
|||||||
|
|
||||||
hidden_size: int
|
hidden_size: int
|
||||||
epsilon: float = 1e-6
|
epsilon: float = 1e-6
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32
|
||||||
bias: bool = True # If True, bias (beta) is added.
|
use_bias: bool = True
|
||||||
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
|
scale: bool = True
|
||||||
# (also e.g. nn.relu), this can be disabled since the scaling will be
|
|
||||||
# done by the next layer.
|
|
||||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
||||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
|
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
|
||||||
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
|
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
"""
|
"""
|
||||||
@@ -146,11 +146,11 @@ class FlaxRobertaLayerNorm(nn.Module):
|
|||||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||||
|
|
||||||
if self.scale:
|
if self.scale:
|
||||||
mul = mul * jnp.asarray(self.gamma)
|
mul = mul * jnp.asarray(self.weight)
|
||||||
y = (x - mean) * mul
|
y = (x - mean) * mul
|
||||||
|
|
||||||
if self.bias:
|
if self.use_bias:
|
||||||
y = y + jnp.asarray(self.beta)
|
y = y + jnp.asarray(self.bias)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
@@ -186,26 +186,21 @@ class FlaxRobertaEmbeddings(nn.Module):
|
|||||||
self.config.vocab_size,
|
self.config.vocab_size,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
initializer_range=self.config.initializer_range,
|
||||||
name="word_embeddings",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.position_embeddings = FlaxRobertaEmbedding(
|
self.position_embeddings = FlaxRobertaEmbedding(
|
||||||
self.config.max_position_embeddings,
|
self.config.max_position_embeddings,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
initializer_range=self.config.initializer_range,
|
||||||
name="position_embeddings",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.token_type_embeddings = FlaxRobertaEmbedding(
|
self.token_type_embeddings = FlaxRobertaEmbedding(
|
||||||
self.config.type_vocab_size,
|
self.config.type_vocab_size,
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
initializer_range=self.config.initializer_range,
|
initializer_range=self.config.initializer_range,
|
||||||
name="token_type_embeddings",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.layer_norm = FlaxRobertaLayerNorm(
|
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||||
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
|
|
||||||
)
|
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
|
||||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||||
@@ -218,38 +213,119 @@ class FlaxRobertaEmbeddings(nn.Module):
|
|||||||
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
||||||
|
|
||||||
# Layer Norm
|
# Layer Norm
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
|
||||||
|
class FlaxRobertaSelfAttention(nn.Module):
|
||||||
|
config: RobertaConfig
|
||||||
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.query = nn.Dense(
|
||||||
|
self.config.hidden_size,
|
||||||
|
dtype=self.dtype,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
)
|
||||||
|
self.key = nn.Dense(
|
||||||
|
self.config.hidden_size,
|
||||||
|
dtype=self.dtype,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
)
|
||||||
|
self.value = nn.Dense(
|
||||||
|
self.config.hidden_size,
|
||||||
|
dtype=self.dtype,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||||
|
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
|
||||||
|
query_states = self.query(hidden_states).reshape(
|
||||||
|
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||||
|
)
|
||||||
|
value_states = self.value(hidden_states).reshape(
|
||||||
|
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||||
|
)
|
||||||
|
key_states = self.key(hidden_states).reshape(
|
||||||
|
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert the boolean attention mask to an attention bias.
|
||||||
|
if attention_mask is not None:
|
||||||
|
# attention mask in the form of attention bias
|
||||||
|
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||||
|
attention_bias = lax.select(
|
||||||
|
attention_mask > 0,
|
||||||
|
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||||
|
jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attention_bias = None
|
||||||
|
|
||||||
|
dropout_rng = None
|
||||||
|
if not deterministic and self.dropout_rate > 0.0:
|
||||||
|
dropout_rng = self.make_rng("dropout")
|
||||||
|
|
||||||
|
attn_output = dot_product_attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
bias=attention_bias,
|
||||||
|
dropout_rng=dropout_rng,
|
||||||
|
dropout_rate=self.config.attention_probs_dropout_prob,
|
||||||
|
broadcast_dropout=True,
|
||||||
|
deterministic=deterministic,
|
||||||
|
dtype=self.dtype,
|
||||||
|
precision=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
|
||||||
|
class FlaxRobertaSelfOutput(nn.Module):
|
||||||
|
config: RobertaConfig
|
||||||
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
self.dense = nn.Dense(
|
||||||
|
self.config.hidden_size,
|
||||||
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size)
|
||||||
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
||||||
class FlaxRobertaAttention(nn.Module):
|
class FlaxRobertaAttention(nn.Module):
|
||||||
config: RobertaConfig
|
config: RobertaConfig
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.self_attention = nn.attention.SelfAttention(
|
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
|
||||||
num_heads=self.config.num_attention_heads,
|
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
|
||||||
qkv_features=self.config.hidden_size,
|
|
||||||
dropout_rate=self.config.attention_probs_dropout_prob,
|
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
|
||||||
bias_init=jax.nn.initializers.zeros,
|
|
||||||
name="self",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
self.layer_norm = FlaxRobertaLayerNorm(
|
|
||||||
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
|
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||||
|
|
||||||
hidden_states = self.layer_norm(self_attn_output + hidden_states)
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -262,7 +338,6 @@ class FlaxRobertaIntermediate(nn.Module):
|
|||||||
self.dense = nn.Dense(
|
self.dense = nn.Dense(
|
||||||
self.config.intermediate_size,
|
self.config.intermediate_size,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
name="dense",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.activation = ACT2FN[self.config.hidden_act]
|
self.activation = ACT2FN[self.config.hidden_act]
|
||||||
@@ -282,18 +357,15 @@ class FlaxRobertaOutput(nn.Module):
|
|||||||
self.dense = nn.Dense(
|
self.dense = nn.Dense(
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
name="dense",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||||
self.layer_norm = FlaxRobertaLayerNorm(
|
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||||
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
hidden_states = self.layer_norm(hidden_states + attention_output)
|
hidden_states = self.LayerNorm(hidden_states + attention_output)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -303,9 +375,9 @@ class FlaxRobertaLayer(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.attention = FlaxRobertaAttention(self.config, name="attention", dtype=self.dtype)
|
self.attention = FlaxRobertaAttention(self.config, dtype=self.dtype)
|
||||||
self.intermediate = FlaxRobertaIntermediate(self.config, name="intermediate", dtype=self.dtype)
|
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
|
||||||
self.output = FlaxRobertaOutput(self.config, name="output", dtype=self.dtype)
|
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
@@ -336,10 +408,10 @@ class FlaxRobertaEncoder(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.layers = FlaxRobertaLayerCollection(self.config, name="layer", dtype=self.dtype)
|
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
|
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
||||||
@@ -351,7 +423,6 @@ class FlaxRobertaPooler(nn.Module):
|
|||||||
self.dense = nn.Dense(
|
self.dense = nn.Dense(
|
||||||
self.config.hidden_size,
|
self.config.hidden_size,
|
||||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||||
name="dense",
|
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -370,75 +441,6 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
config_class = RobertaConfig
|
config_class = RobertaConfig
|
||||||
base_model_prefix = "roberta"
|
base_model_prefix = "roberta"
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def convert_from_pytorch(pt_state: Dict, config: RobertaConfig) -> Dict:
|
|
||||||
jax_state = dict(pt_state)
|
|
||||||
|
|
||||||
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
|
|
||||||
for key, tensor in pt_state.items():
|
|
||||||
# Key parts
|
|
||||||
key_parts = set(key.split("."))
|
|
||||||
|
|
||||||
# Every dense layer has "kernel" parameters instead of "weight"
|
|
||||||
if "dense.weight" in key:
|
|
||||||
del jax_state[key]
|
|
||||||
key = key.replace("weight", "kernel")
|
|
||||||
jax_state[key] = tensor
|
|
||||||
|
|
||||||
# SelfAttention needs also to replace "weight" by "kernel"
|
|
||||||
if {"query", "key", "value"} & key_parts:
|
|
||||||
|
|
||||||
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
|
||||||
if "bias" in key:
|
|
||||||
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
|
||||||
elif "weight":
|
|
||||||
del jax_state[key]
|
|
||||||
key = key.replace("weight", "kernel")
|
|
||||||
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
|
|
||||||
jax_state[key] = tensor
|
|
||||||
|
|
||||||
# SelfAttention output is not a separate layer, remove one nesting
|
|
||||||
if "attention.output.dense" in key:
|
|
||||||
del jax_state[key]
|
|
||||||
key = key.replace("attention.output.dense", "attention.self.out")
|
|
||||||
jax_state[key] = tensor
|
|
||||||
|
|
||||||
# SelfAttention output is not a separate layer, remove nesting on layer norm
|
|
||||||
if "attention.output.LayerNorm" in key:
|
|
||||||
del jax_state[key]
|
|
||||||
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
|
|
||||||
jax_state[key] = tensor
|
|
||||||
|
|
||||||
# There are some transposed parameters w.r.t their PyTorch counterpart
|
|
||||||
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
|
|
||||||
jax_state[key] = tensor.T
|
|
||||||
|
|
||||||
# Self Attention output projection needs to be transposed
|
|
||||||
if "out.kernel" in key:
|
|
||||||
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
|
|
||||||
1, 2, 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pooler needs to transpose its kernel
|
|
||||||
if "pooler.dense.kernel" in key:
|
|
||||||
jax_state[key] = tensor.T
|
|
||||||
|
|
||||||
# Handle LayerNorm conversion
|
|
||||||
if "LayerNorm" in key:
|
|
||||||
del jax_state[key]
|
|
||||||
|
|
||||||
# Replace LayerNorm by layer_norm
|
|
||||||
new_key = key.replace("LayerNorm", "layer_norm")
|
|
||||||
|
|
||||||
if "weight" in key:
|
|
||||||
new_key = new_key.replace("weight", "gamma")
|
|
||||||
elif "bias" in key:
|
|
||||||
new_key = new_key.replace("bias", "beta")
|
|
||||||
|
|
||||||
jax_state[new_key] = tensor
|
|
||||||
|
|
||||||
return jax_state
|
|
||||||
|
|
||||||
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||||
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
||||||
@@ -523,9 +525,9 @@ class FlaxRobertaModule(nn.Module):
|
|||||||
add_pooling_layer: bool = True
|
add_pooling_layer: bool = True
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.embeddings = FlaxRobertaEmbeddings(self.config, name="embeddings", dtype=self.dtype)
|
self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
|
||||||
self.encoder = FlaxRobertaEncoder(self.config, name="encoder", dtype=self.dtype)
|
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
|
||||||
self.pooler = FlaxRobertaPooler(self.config, name="pooler", dtype=self.dtype)
|
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
|
||||||
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||||
|
|
||||||
|
|||||||
@@ -115,6 +115,6 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_class_name in self.all_model_classes:
|
for model_class_name in self.all_model_classes:
|
||||||
model = model_class_name.from_pretrained("bert-base-cased")
|
model = model_class_name.from_pretrained("bert-base-cased", from_pt=True)
|
||||||
outputs = model(np.ones((1, 1)))
|
outputs = model(np.ones((1, 1)))
|
||||||
self.assertIsNotNone(outputs)
|
self.assertIsNotNone(outputs)
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from transformers.modeling_flax_utils import convert_state_dict_from_pt
|
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
||||||
|
|
||||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||||
|
|
||||||
@@ -79,8 +79,8 @@ class FlaxModelTesterMixin:
|
|||||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||||
pt_model = pt_model_class(config).eval()
|
pt_model = pt_model_class(config).eval()
|
||||||
|
|
||||||
fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
|
|
||||||
fx_model = model_class(config, dtype=jnp.float32)
|
fx_model = model_class(config, dtype=jnp.float32)
|
||||||
|
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||||
fx_model.params = fx_state
|
fx_model.params = fx_state
|
||||||
|
|
||||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
|
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
|
||||||
|
|||||||
@@ -115,6 +115,6 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_class_name in self.all_model_classes:
|
for model_class_name in self.all_model_classes:
|
||||||
model = model_class_name.from_pretrained("roberta-base")
|
model = model_class_name.from_pretrained("roberta-base", from_pt=True)
|
||||||
outputs = model(np.ones((1, 1)))
|
outputs = model(np.ones((1, 1)))
|
||||||
self.assertIsNotNone(outputs)
|
self.assertIsNotNone(outputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user