[WIP][Flax] Add general conversion script (#10809)
* save intermediate * finish first version * delete some more * improve import * fix roberta * Update src/transformers/modeling_flax_pytorch_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_flax_pytorch_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * small corrections * apply all comments * fix deterministic * make fix-copies Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
604c085087
commit
8780caa388
100
src/transformers/modeling_flax_pytorch_utils.py
Normal file
100
src/transformers/modeling_flax_pytorch_utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch - TF 2.0 general utilities."""
|
||||
|
||||
|
||||
import os
|
||||
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
#####################
|
||||
# PyTorch => Flax #
|
||||
#####################
|
||||
|
||||
|
||||
def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
|
||||
"""Load pytorch checkpoints in a flax model"""
|
||||
try:
|
||||
import torch # noqa: F401
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see "
|
||||
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
|
||||
)
|
||||
raise
|
||||
|
||||
pt_path = os.path.abspath(pytorch_checkpoint_path)
|
||||
logger.info("Loading PyTorch weights from {}".format(pt_path))
|
||||
|
||||
pt_state_dict = torch.load(pt_path, map_location="cpu")
|
||||
logger.info("PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values())} parameters.")
|
||||
|
||||
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
|
||||
|
||||
return flax_state_dict
|
||||
|
||||
|
||||
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||
# convert pytorch tensor to numpy
|
||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||
|
||||
random_flax_state_dict = flatten_dict(unfreeze(flax_model.params))
|
||||
flax_state_dict = {}
|
||||
|
||||
remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
|
||||
flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
||||
)
|
||||
add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
|
||||
flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
||||
)
|
||||
|
||||
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
|
||||
for pt_key, pt_tensor in pt_state_dict.items():
|
||||
|
||||
pt_tuple_key = tuple(pt_key.split("."))
|
||||
|
||||
has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
|
||||
require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
|
||||
|
||||
if remove_base_model_prefix and has_base_model_prefix:
|
||||
pt_tuple_key = pt_tuple_key[1:]
|
||||
elif add_base_model_prefix and require_base_model_prefix:
|
||||
pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
|
||||
|
||||
if pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
|
||||
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
||||
pt_tensor = pt_tensor.T
|
||||
elif pt_tuple_key[-1] == "gamma":
|
||||
pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
||||
elif pt_tuple_key[-1] == "beta":
|
||||
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
||||
|
||||
if pt_tuple_key in random_flax_state_dict:
|
||||
if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape:
|
||||
raise ValueError(
|
||||
"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
|
||||
)
|
||||
|
||||
# add unexpected weight so that warning is thrown
|
||||
flax_state_dict[pt_tuple_key] = pt_tensor
|
||||
|
||||
return unflatten_dict(flax_state_dict)
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import ABC
|
||||
from functools import partial
|
||||
from pickle import UnpicklingError
|
||||
from typing import Dict, Set, Tuple, Union
|
||||
@@ -29,6 +29,7 @@ from jax.random import PRNGKey
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
|
||||
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
|
||||
from .utils import logging
|
||||
|
||||
|
||||
@@ -121,11 +122,6 @@ class FlaxPreTrainedModel(ABC):
|
||||
)
|
||||
self._params = freeze(params)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -307,25 +303,18 @@ class FlaxPreTrainedModel(ABC):
|
||||
else:
|
||||
resolved_archive_file = None
|
||||
|
||||
# Instantiate model.
|
||||
with open(resolved_archive_file, "rb") as state_f:
|
||||
try:
|
||||
if from_pt:
|
||||
import torch
|
||||
|
||||
state = torch.load(state_f)
|
||||
|
||||
state = convert_state_dict_from_pt(cls, state, config)
|
||||
else:
|
||||
state = from_bytes(cls, state_f.read())
|
||||
except UnpicklingError:
|
||||
raise EnvironmentError(
|
||||
f"Unable to convert pytorch model {archive_file} to Flax deserializable object. "
|
||||
)
|
||||
|
||||
# init random models
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if from_pt:
|
||||
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
|
||||
else:
|
||||
with open(resolved_archive_file, "rb") as state_f:
|
||||
try:
|
||||
state = from_bytes(cls, state_f.read())
|
||||
except UnpicklingError:
|
||||
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
|
||||
|
||||
# if model is base model only use model_prefix key
|
||||
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
||||
state = state[cls.base_model_prefix]
|
||||
@@ -341,6 +330,10 @@ class FlaxPreTrainedModel(ABC):
|
||||
for missing_key in missing_keys:
|
||||
state[missing_key] = random_state[missing_key]
|
||||
|
||||
# remove unexpected keys to not be saved again
|
||||
for unexpected_key in unexpected_keys:
|
||||
del state[unexpected_key]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
||||
@@ -393,13 +386,3 @@ class FlaxPreTrainedModel(ABC):
|
||||
with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
|
||||
model_bytes = to_bytes(self.params)
|
||||
f.write(model_bytes)
|
||||
|
||||
|
||||
def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig):
|
||||
"""
|
||||
Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict
|
||||
"""
|
||||
state = {k: v.numpy() for k, v in state.items()}
|
||||
state = model_class.convert_from_pytorch(state, config)
|
||||
state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
|
||||
return state
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, Tuple
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -21,6 +21,8 @@ import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.linen import dot_product_attention
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
@@ -99,17 +101,15 @@ class FlaxBertLayerNorm(nn.Module):
|
||||
|
||||
hidden_size: int
|
||||
epsilon: float = 1e-6
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
bias: bool = True # If True, bias (beta) is added.
|
||||
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
|
||||
# (also e.g. nn.relu), this can be disabled since the scaling will be
|
||||
# done by the next layer.
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
use_bias: bool = True
|
||||
scale: bool = True
|
||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||
|
||||
def setup(self):
|
||||
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
|
||||
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
|
||||
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
|
||||
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
|
||||
|
||||
def __call__(self, x):
|
||||
"""
|
||||
@@ -129,11 +129,11 @@ class FlaxBertLayerNorm(nn.Module):
|
||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||
|
||||
if self.scale:
|
||||
mul = mul * jnp.asarray(self.gamma)
|
||||
mul = mul * jnp.asarray(self.weight)
|
||||
y = (x - mean) * mul
|
||||
|
||||
if self.bias:
|
||||
y = y + jnp.asarray(self.beta)
|
||||
if self.use_bias:
|
||||
y = y + jnp.asarray(self.bias)
|
||||
return y
|
||||
|
||||
|
||||
@@ -167,24 +167,21 @@ class FlaxBertEmbeddings(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
name="word_embeddings",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = FlaxBertEmbedding(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
name="position_embeddings",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = FlaxBertEmbedding(
|
||||
self.config.type_vocab_size,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
name="token_type_embeddings",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||
@@ -197,35 +194,116 @@ class FlaxBertEmbeddings(nn.Module):
|
||||
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
||||
|
||||
# Layer Norm
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxBertSelfAttention(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
|
||||
)
|
||||
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
value_states = self.value(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
key_states = self.key(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
|
||||
# Convert the boolean attention mask to an attention bias.
|
||||
if attention_mask is not None:
|
||||
# attention mask in the form of attention bias
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
|
||||
)
|
||||
else:
|
||||
attention_bias = None
|
||||
|
||||
dropout_rng = None
|
||||
if not deterministic and self.dropout_rate > 0.0:
|
||||
dropout_rng = self.make_rng("dropout")
|
||||
|
||||
attn_output = dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
bias=attention_bias,
|
||||
dropout_rng=dropout_rng,
|
||||
dropout_rate=self.config.attention_probs_dropout_prob,
|
||||
broadcast_dropout=True,
|
||||
deterministic=deterministic,
|
||||
dtype=self.dtype,
|
||||
precision=None,
|
||||
)
|
||||
|
||||
return attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||
|
||||
|
||||
class FlaxBertSelfOutput(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxBertAttention(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.self_attention = nn.attention.SelfAttention(
|
||||
num_heads=self.config.num_attention_heads,
|
||||
qkv_features=self.config.hidden_size,
|
||||
dropout_rate=self.config.attention_probs_dropout_prob,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
bias_init=jax.nn.initializers.zeros,
|
||||
name="self",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
||||
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
|
||||
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||
|
||||
hidden_states = self.layer_norm(self_attn_output + hidden_states)
|
||||
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
|
||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -237,7 +315,6 @@ class FlaxBertIntermediate(nn.Module):
|
||||
self.dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
@@ -256,16 +333,15 @@ class FlaxBertOutput(nn.Module):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
hidden_states = self.layer_norm(hidden_states + attention_output)
|
||||
hidden_states = self.LayerNorm(hidden_states + attention_output)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -274,9 +350,9 @@ class FlaxBertLayer(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.attention = FlaxBertAttention(self.config, name="attention", dtype=self.dtype)
|
||||
self.intermediate = FlaxBertIntermediate(self.config, name="intermediate", dtype=self.dtype)
|
||||
self.output = FlaxBertOutput(self.config, name="output", dtype=self.dtype)
|
||||
self.attention = FlaxBertAttention(self.config, dtype=self.dtype)
|
||||
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
|
||||
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||
@@ -305,10 +381,10 @@ class FlaxBertEncoder(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.layers = FlaxBertLayerCollection(self.config, name="layer", dtype=self.dtype)
|
||||
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
|
||||
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||
|
||||
|
||||
class FlaxBertPooler(nn.Module):
|
||||
@@ -319,7 +395,6 @@ class FlaxBertPooler(nn.Module):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -334,14 +409,14 @@ class FlaxBertPredictionHeadTransform(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(self.config.hidden_size, name="dense", dtype=self.dtype)
|
||||
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
|
||||
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
return self.layer_norm(hidden_states)
|
||||
return self.LayerNorm(hidden_states)
|
||||
|
||||
|
||||
class FlaxBertLMPredictionHead(nn.Module):
|
||||
@@ -349,14 +424,10 @@ class FlaxBertLMPredictionHead(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.transform = FlaxBertPredictionHeadTransform(self.config, name="transform", dtype=self.dtype)
|
||||
self.decoder = nn.Dense(self.config.vocab_size, name="decoder", dtype=self.dtype)
|
||||
self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
|
||||
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
# TODO: The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
# Need a link between the two variables so that the bias is correctly
|
||||
# resized with `resize_token_embeddings`
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
return hidden_states
|
||||
@@ -367,10 +438,10 @@ class FlaxBertOnlyMLMHead(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.mlm_head = FlaxBertLMPredictionHead(self.config, name="predictions", dtype=self.dtype)
|
||||
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
hidden_states = self.mlm_head(hidden_states)
|
||||
hidden_states = self.predictions(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -405,85 +476,6 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
||||
|
||||
@staticmethod
|
||||
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
|
||||
jax_state = dict(pt_state)
|
||||
|
||||
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
|
||||
for key, tensor in pt_state.items():
|
||||
# Key parts
|
||||
key_parts = set(key.split("."))
|
||||
|
||||
# Every dense layer has "kernel" parameters instead of "weight"
|
||||
if "dense.weight" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
jax_state[key] = tensor
|
||||
|
||||
if "decoder.weight" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# SelfAttention needs also to replace "weight" by "kernel"
|
||||
if {"query", "key", "value"} & key_parts:
|
||||
|
||||
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
||||
if "bias" in key:
|
||||
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
||||
elif "weight":
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove one nesting
|
||||
if "attention.output.dense" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.dense", "attention.self.out")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove nesting on layer norm
|
||||
if "attention.output.LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# There are some transposed parameters w.r.t their PyTorch counterpart
|
||||
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key or "transform.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Self Attention output projection needs to be transposed
|
||||
if "out.kernel" in key:
|
||||
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
|
||||
1, 2, 0
|
||||
)
|
||||
|
||||
# Pooler needs to transpose its kernel
|
||||
if "pooler.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Hack to correctly load some pytorch models
|
||||
if "predictions.bias" in key:
|
||||
del jax_state[key]
|
||||
jax_state[".".join(key.split(".")[:2]) + ".decoder.bias"] = tensor
|
||||
|
||||
# Handle LayerNorm conversion
|
||||
if "LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
|
||||
# Replace LayerNorm by layer_norm
|
||||
new_key = key.replace("LayerNorm", "layer_norm")
|
||||
|
||||
if "weight" in key:
|
||||
new_key = new_key.replace("weight", "gamma")
|
||||
elif "bias" in key:
|
||||
new_key = new_key.replace("bias", "beta")
|
||||
|
||||
jax_state[new_key] = tensor
|
||||
|
||||
return jax_state
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
@@ -541,9 +533,9 @@ class FlaxBertModule(nn.Module):
|
||||
add_pooling_layer: bool = True
|
||||
|
||||
def setup(self):
|
||||
self.embeddings = FlaxBertEmbeddings(self.config, name="embeddings", dtype=self.dtype)
|
||||
self.encoder = FlaxBertEncoder(self.config, name="encoder", dtype=self.dtype)
|
||||
self.pooler = FlaxBertPooler(self.config, name="pooler", dtype=self.dtype)
|
||||
self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
|
||||
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
|
||||
@@ -602,15 +594,13 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.encoder = FlaxBertModule(
|
||||
self.bert = FlaxBertModule(
|
||||
config=self.config,
|
||||
add_pooling_layer=False,
|
||||
name="bert",
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.mlm_head = FlaxBertOnlyMLMHead(
|
||||
self.cls = FlaxBertOnlyMLMHead(
|
||||
config=self.config,
|
||||
name="cls",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -618,12 +608,10 @@ class FlaxBertForMaskedLMModule(nn.Module):
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
):
|
||||
# Model
|
||||
hidden_states = self.encoder(
|
||||
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
|
||||
)
|
||||
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||
|
||||
# Compute the prediction scores
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
logits = self.mlm_head(hidden_states)
|
||||
logits = self.cls(hidden_states)
|
||||
|
||||
return (logits,)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Callable, Dict, Tuple
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -20,6 +20,8 @@ import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.linen import dot_product_attention
|
||||
from jax import lax
|
||||
from jax.random import PRNGKey
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
@@ -116,17 +118,15 @@ class FlaxRobertaLayerNorm(nn.Module):
|
||||
|
||||
hidden_size: int
|
||||
epsilon: float = 1e-6
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
bias: bool = True # If True, bias (beta) is added.
|
||||
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
|
||||
# (also e.g. nn.relu), this can be disabled since the scaling will be
|
||||
# done by the next layer.
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
use_bias: bool = True
|
||||
scale: bool = True
|
||||
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
|
||||
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
|
||||
|
||||
def setup(self):
|
||||
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
|
||||
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
|
||||
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
|
||||
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
|
||||
|
||||
def __call__(self, x):
|
||||
"""
|
||||
@@ -146,11 +146,11 @@ class FlaxRobertaLayerNorm(nn.Module):
|
||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||
|
||||
if self.scale:
|
||||
mul = mul * jnp.asarray(self.gamma)
|
||||
mul = mul * jnp.asarray(self.weight)
|
||||
y = (x - mean) * mul
|
||||
|
||||
if self.bias:
|
||||
y = y + jnp.asarray(self.beta)
|
||||
if self.use_bias:
|
||||
y = y + jnp.asarray(self.bias)
|
||||
return y
|
||||
|
||||
|
||||
@@ -186,26 +186,21 @@ class FlaxRobertaEmbeddings(nn.Module):
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
name="word_embeddings",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.position_embeddings = FlaxRobertaEmbedding(
|
||||
self.config.max_position_embeddings,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
name="position_embeddings",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.token_type_embeddings = FlaxRobertaEmbedding(
|
||||
self.config.type_vocab_size,
|
||||
self.config.hidden_size,
|
||||
initializer_range=self.config.initializer_range,
|
||||
name="token_type_embeddings",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.layer_norm = FlaxRobertaLayerNorm(
|
||||
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
|
||||
)
|
||||
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
|
||||
@@ -218,38 +213,119 @@ class FlaxRobertaEmbeddings(nn.Module):
|
||||
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
|
||||
|
||||
# Layer Norm
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
|
||||
class FlaxRobertaSelfAttention(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
|
||||
)
|
||||
|
||||
self.query = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
self.key = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
self.value = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
dtype=self.dtype,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
|
||||
query_states = self.query(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
value_states = self.value(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
key_states = self.key(hidden_states).reshape(
|
||||
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
||||
)
|
||||
|
||||
# Convert the boolean attention mask to an attention bias.
|
||||
if attention_mask is not None:
|
||||
# attention mask in the form of attention bias
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
|
||||
)
|
||||
else:
|
||||
attention_bias = None
|
||||
|
||||
dropout_rng = None
|
||||
if not deterministic and self.dropout_rate > 0.0:
|
||||
dropout_rng = self.make_rng("dropout")
|
||||
|
||||
attn_output = dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
bias=attention_bias,
|
||||
dropout_rng=dropout_rng,
|
||||
dropout_rate=self.config.attention_probs_dropout_prob,
|
||||
broadcast_dropout=True,
|
||||
deterministic=deterministic,
|
||||
dtype=self.dtype,
|
||||
precision=None,
|
||||
)
|
||||
|
||||
return attn_output.reshape(attn_output.shape[:2] + (-1,))
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
|
||||
class FlaxRobertaSelfOutput(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
|
||||
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
|
||||
class FlaxRobertaAttention(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.self_attention = nn.attention.SelfAttention(
|
||||
num_heads=self.config.num_attention_heads,
|
||||
qkv_features=self.config.hidden_size,
|
||||
dropout_rate=self.config.attention_probs_dropout_prob,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
bias_init=jax.nn.initializers.zeros,
|
||||
name="self",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.layer_norm = FlaxRobertaLayerNorm(
|
||||
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
|
||||
)
|
||||
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
|
||||
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic=True):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||
|
||||
hidden_states = self.layer_norm(self_attn_output + hidden_states)
|
||||
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
|
||||
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -262,7 +338,6 @@ class FlaxRobertaIntermediate(nn.Module):
|
||||
self.dense = nn.Dense(
|
||||
self.config.intermediate_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.activation = ACT2FN[self.config.hidden_act]
|
||||
@@ -282,18 +357,15 @@ class FlaxRobertaOutput(nn.Module):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.layer_norm = FlaxRobertaLayerNorm(
|
||||
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
|
||||
)
|
||||
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
hidden_states = self.layer_norm(hidden_states + attention_output)
|
||||
hidden_states = self.LayerNorm(hidden_states + attention_output)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -303,9 +375,9 @@ class FlaxRobertaLayer(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.attention = FlaxRobertaAttention(self.config, name="attention", dtype=self.dtype)
|
||||
self.intermediate = FlaxRobertaIntermediate(self.config, name="intermediate", dtype=self.dtype)
|
||||
self.output = FlaxRobertaOutput(self.config, name="output", dtype=self.dtype)
|
||||
self.attention = FlaxRobertaAttention(self.config, dtype=self.dtype)
|
||||
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
|
||||
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
|
||||
@@ -336,10 +408,10 @@ class FlaxRobertaEncoder(nn.Module):
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
|
||||
def setup(self):
|
||||
self.layers = FlaxRobertaLayerCollection(self.config, name="layer", dtype=self.dtype)
|
||||
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
|
||||
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
|
||||
@@ -351,7 +423,6 @@ class FlaxRobertaPooler(nn.Module):
|
||||
self.dense = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
|
||||
name="dense",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -370,75 +441,6 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
|
||||
@staticmethod
|
||||
def convert_from_pytorch(pt_state: Dict, config: RobertaConfig) -> Dict:
|
||||
jax_state = dict(pt_state)
|
||||
|
||||
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
|
||||
for key, tensor in pt_state.items():
|
||||
# Key parts
|
||||
key_parts = set(key.split("."))
|
||||
|
||||
# Every dense layer has "kernel" parameters instead of "weight"
|
||||
if "dense.weight" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention needs also to replace "weight" by "kernel"
|
||||
if {"query", "key", "value"} & key_parts:
|
||||
|
||||
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
||||
if "bias" in key:
|
||||
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
||||
elif "weight":
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove one nesting
|
||||
if "attention.output.dense" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.dense", "attention.self.out")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove nesting on layer norm
|
||||
if "attention.output.LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# There are some transposed parameters w.r.t their PyTorch counterpart
|
||||
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Self Attention output projection needs to be transposed
|
||||
if "out.kernel" in key:
|
||||
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
|
||||
1, 2, 0
|
||||
)
|
||||
|
||||
# Pooler needs to transpose its kernel
|
||||
if "pooler.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Handle LayerNorm conversion
|
||||
if "LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
|
||||
# Replace LayerNorm by layer_norm
|
||||
new_key = key.replace("LayerNorm", "layer_norm")
|
||||
|
||||
if "weight" in key:
|
||||
new_key = new_key.replace("weight", "gamma")
|
||||
elif "bias" in key:
|
||||
new_key = new_key.replace("bias", "beta")
|
||||
|
||||
jax_state[new_key] = tensor
|
||||
|
||||
return jax_state
|
||||
|
||||
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
||||
@@ -523,9 +525,9 @@ class FlaxRobertaModule(nn.Module):
|
||||
add_pooling_layer: bool = True
|
||||
|
||||
def setup(self):
|
||||
self.embeddings = FlaxRobertaEmbeddings(self.config, name="embeddings", dtype=self.dtype)
|
||||
self.encoder = FlaxRobertaEncoder(self.config, name="encoder", dtype=self.dtype)
|
||||
self.pooler = FlaxRobertaPooler(self.config, name="pooler", dtype=self.dtype)
|
||||
self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
|
||||
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||
|
||||
|
||||
@@ -115,6 +115,6 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
model = model_class_name.from_pretrained("bert-base-cased")
|
||||
model = model_class_name.from_pretrained("bert-base-cased", from_pt=True)
|
||||
outputs = model(np.ones((1, 1)))
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
@@ -27,7 +27,7 @@ if is_flax_available():
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from transformers.modeling_flax_utils import convert_state_dict_from_pt
|
||||
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
@@ -79,8 +79,8 @@ class FlaxModelTesterMixin:
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
pt_model = pt_model_class(config).eval()
|
||||
|
||||
fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
|
||||
|
||||
@@ -115,6 +115,6 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
model = model_class_name.from_pretrained("roberta-base")
|
||||
model = model_class_name.from_pretrained("roberta-base", from_pt=True)
|
||||
outputs = model(np.ones((1, 1)))
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
Reference in New Issue
Block a user