[WIP][Flax] Add general conversion script (#10809)

* save intermediate

* finish first version

* delete some more

* improve import

* fix roberta

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* small corrections

* apply all comments

* fix deterministic

* make fix-copies

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2021-03-30 12:13:59 +03:00
committed by GitHub
parent 604c085087
commit 8780caa388
7 changed files with 370 additions and 297 deletions

View File

@@ -0,0 +1,100 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - TF 2.0 general utilities."""
import os
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from .utils import logging
logger = logging.get_logger(__name__)
#####################
# PyTorch => Flax #
#####################
def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False):
"""Load pytorch checkpoints in a flax model"""
try:
import torch # noqa: F401
except ImportError:
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see "
"https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions."
)
raise
pt_path = os.path.abspath(pytorch_checkpoint_path)
logger.info("Loading PyTorch weights from {}".format(pt_path))
pt_state_dict = torch.load(pt_path, map_location="cpu")
logger.info("PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values())} parameters.")
flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
return flax_state_dict
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# convert pytorch tensor to numpy
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
random_flax_state_dict = flatten_dict(unfreeze(flax_model.params))
flax_state_dict = {}
remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
)
add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
)
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split("."))
has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
if remove_base_model_prefix and has_base_model_prefix:
pt_tuple_key = pt_tuple_key[1:]
elif add_base_model_prefix and require_base_model_prefix:
pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
if pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
pt_tensor = pt_tensor.T
elif pt_tuple_key[-1] == "gamma":
pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
elif pt_tuple_key[-1] == "beta":
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
if pt_tuple_key in random_flax_state_dict:
if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape:
raise ValueError(
"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
)
# add unexpected weight so that warning is thrown
flax_state_dict[pt_tuple_key] = pt_tensor
return unflatten_dict(flax_state_dict)

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import os
from abc import ABC, abstractmethod
from abc import ABC
from functools import partial
from pickle import UnpicklingError
from typing import Dict, Set, Tuple, Union
@@ -29,6 +29,7 @@ from jax.random import PRNGKey
from .configuration_utils import PretrainedConfig
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import logging
@@ -121,11 +122,6 @@ class FlaxPreTrainedModel(ABC):
)
self._params = freeze(params)
@staticmethod
@abstractmethod
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
raise NotImplementedError()
@classmethod
def from_pretrained(
cls,
@@ -307,25 +303,18 @@ class FlaxPreTrainedModel(ABC):
else:
resolved_archive_file = None
# Instantiate model.
with open(resolved_archive_file, "rb") as state_f:
try:
if from_pt:
import torch
state = torch.load(state_f)
state = convert_state_dict_from_pt(cls, state, config)
else:
state = from_bytes(cls, state_f.read())
except UnpicklingError:
raise EnvironmentError(
f"Unable to convert pytorch model {archive_file} to Flax deserializable object. "
)
# init random models
model = cls(config, *model_args, **model_kwargs)
if from_pt:
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
else:
with open(resolved_archive_file, "rb") as state_f:
try:
state = from_bytes(cls, state_f.read())
except UnpicklingError:
raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ")
# if model is base model only use model_prefix key
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
state = state[cls.base_model_prefix]
@@ -341,6 +330,10 @@ class FlaxPreTrainedModel(ABC):
for missing_key in missing_keys:
state[missing_key] = random_state[missing_key]
# remove unexpected keys to not be saved again
for unexpected_key in unexpected_keys:
del state[unexpected_key]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
@@ -393,13 +386,3 @@ class FlaxPreTrainedModel(ABC):
with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
model_bytes = to_bytes(self.params)
f.write(model_bytes)
def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig):
"""
Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict
"""
state = {k: v.numpy() for k, v in state.items()}
state = model_class.convert_from_pytorch(state, config)
state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
return state

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Tuple
from typing import Callable, Tuple
import numpy as np
@@ -21,6 +21,8 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import dot_product_attention
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
@@ -99,17 +101,15 @@ class FlaxBertLayerNorm(nn.Module):
hidden_size: int
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
bias: bool = True # If True, bias (beta) is added.
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
# (also e.g. nn.relu), this can be disabled since the scaling will be
# done by the next layer.
dtype: jnp.dtype = jnp.float32
use_bias: bool = True
scale: bool = True
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
def __call__(self, x):
"""
@@ -129,11 +129,11 @@ class FlaxBertLayerNorm(nn.Module):
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.gamma)
mul = mul * jnp.asarray(self.weight)
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.beta)
if self.use_bias:
y = y + jnp.asarray(self.bias)
return y
@@ -167,24 +167,21 @@ class FlaxBertEmbeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="word_embeddings",
dtype=self.dtype,
)
self.position_embeddings = FlaxBertEmbedding(
self.config.max_position_embeddings,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="position_embeddings",
dtype=self.dtype,
)
self.token_type_embeddings = FlaxBertEmbedding(
self.config.type_vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="token_type_embeddings",
dtype=self.dtype,
)
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
@@ -197,35 +194,116 @@ class FlaxBertEmbeddings(nn.Module):
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
# Layer Norm
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxBertSelfAttention(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
)
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
def __call__(self, hidden_states, attention_mask, deterministic=True):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
value_states = self.value(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
key_states = self.key(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.dropout_rate > 0.0:
dropout_rng = self.make_rng("dropout")
attn_output = dot_product_attention(
query_states,
key_states,
value_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_probs_dropout_prob,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
return attn_output.reshape(attn_output.shape[:2] + (-1,))
class FlaxBertSelfOutput(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class FlaxBertAttention(nn.Module):
config: BertConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
dtype: jnp.dtype = jnp.float32
def setup(self):
self.self_attention = nn.attention.SelfAttention(
num_heads=self.config.num_attention_heads,
qkv_features=self.config.hidden_size,
dropout_rate=self.config.attention_probs_dropout_prob,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)
self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic=True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = self.layer_norm(self_attn_output + hidden_states)
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
return hidden_states
@@ -237,7 +315,6 @@ class FlaxBertIntermediate(nn.Module):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
@@ -256,16 +333,15 @@ class FlaxBertOutput(nn.Module):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.layer_norm(hidden_states + attention_output)
hidden_states = self.LayerNorm(hidden_states + attention_output)
return hidden_states
@@ -274,9 +350,9 @@ class FlaxBertLayer(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.attention = FlaxBertAttention(self.config, name="attention", dtype=self.dtype)
self.intermediate = FlaxBertIntermediate(self.config, name="intermediate", dtype=self.dtype)
self.output = FlaxBertOutput(self.config, name="output", dtype=self.dtype)
self.attention = FlaxBertAttention(self.config, dtype=self.dtype)
self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
self.output = FlaxBertOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
@@ -305,10 +381,10 @@ class FlaxBertEncoder(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = FlaxBertLayerCollection(self.config, name="layer", dtype=self.dtype)
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
class FlaxBertPooler(nn.Module):
@@ -319,7 +395,6 @@ class FlaxBertPooler(nn.Module):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
@@ -334,14 +409,14 @@ class FlaxBertPredictionHeadTransform(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(self.config.hidden_size, name="dense", dtype=self.dtype)
self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
self.activation = ACT2FN[self.config.hidden_act]
self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype)
self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
return self.layer_norm(hidden_states)
return self.LayerNorm(hidden_states)
class FlaxBertLMPredictionHead(nn.Module):
@@ -349,14 +424,10 @@ class FlaxBertLMPredictionHead(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.transform = FlaxBertPredictionHeadTransform(self.config, name="transform", dtype=self.dtype)
self.decoder = nn.Dense(self.config.vocab_size, name="decoder", dtype=self.dtype)
self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype)
def __call__(self, hidden_states):
# TODO: The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
# Need a link between the two variables so that the bias is correctly
# resized with `resize_token_embeddings`
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
@@ -367,10 +438,10 @@ class FlaxBertOnlyMLMHead(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.mlm_head = FlaxBertLMPredictionHead(self.config, name="predictions", dtype=self.dtype)
self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
def __call__(self, hidden_states):
hidden_states = self.mlm_head(hidden_states)
hidden_states = self.predictions(hidden_states)
return hidden_states
@@ -405,85 +476,6 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
@staticmethod
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
jax_state = dict(pt_state)
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
for key, tensor in pt_state.items():
# Key parts
key_parts = set(key.split("."))
# Every dense layer has "kernel" parameters instead of "weight"
if "dense.weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
jax_state[key] = tensor
if "decoder.weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
jax_state[key] = tensor.T
# SelfAttention needs also to replace "weight" by "kernel"
if {"query", "key", "value"} & key_parts:
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
if "bias" in key:
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
elif "weight":
del jax_state[key]
key = key.replace("weight", "kernel")
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove one nesting
if "attention.output.dense" in key:
del jax_state[key]
key = key.replace("attention.output.dense", "attention.self.out")
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove nesting on layer norm
if "attention.output.LayerNorm" in key:
del jax_state[key]
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
jax_state[key] = tensor
# There are some transposed parameters w.r.t their PyTorch counterpart
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key or "transform.dense.kernel" in key:
jax_state[key] = tensor.T
# Self Attention output projection needs to be transposed
if "out.kernel" in key:
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
1, 2, 0
)
# Pooler needs to transpose its kernel
if "pooler.dense.kernel" in key:
jax_state[key] = tensor.T
# Hack to correctly load some pytorch models
if "predictions.bias" in key:
del jax_state[key]
jax_state[".".join(key.split(".")[:2]) + ".decoder.bias"] = tensor
# Handle LayerNorm conversion
if "LayerNorm" in key:
del jax_state[key]
# Replace LayerNorm by layer_norm
new_key = key.replace("LayerNorm", "layer_norm")
if "weight" in key:
new_key = new_key.replace("weight", "gamma")
elif "bias" in key:
new_key = new_key.replace("bias", "beta")
jax_state[new_key] = tensor
return jax_state
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
@@ -541,9 +533,9 @@ class FlaxBertModule(nn.Module):
add_pooling_layer: bool = True
def setup(self):
self.embeddings = FlaxBertEmbeddings(self.config, name="embeddings", dtype=self.dtype)
self.encoder = FlaxBertEncoder(self.config, name="encoder", dtype=self.dtype)
self.pooler = FlaxBertPooler(self.config, name="pooler", dtype=self.dtype)
self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
@@ -602,15 +594,13 @@ class FlaxBertForMaskedLMModule(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
self.encoder = FlaxBertModule(
self.bert = FlaxBertModule(
config=self.config,
add_pooling_layer=False,
name="bert",
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.mlm_head = FlaxBertOnlyMLMHead(
self.cls = FlaxBertOnlyMLMHead(
config=self.config,
name="cls",
dtype=self.dtype,
)
@@ -618,12 +608,10 @@ class FlaxBertForMaskedLMModule(nn.Module):
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
):
# Model
hidden_states = self.encoder(
input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic
)
hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
# Compute the prediction scores
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.mlm_head(hidden_states)
logits = self.cls(hidden_states)
return (logits,)

View File

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Tuple
from typing import Callable, Tuple
import numpy as np
@@ -20,6 +20,8 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.linen import dot_product_attention
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
@@ -116,17 +118,15 @@ class FlaxRobertaLayerNorm(nn.Module):
hidden_size: int
epsilon: float = 1e-6
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
bias: bool = True # If True, bias (beta) is added.
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
# (also e.g. nn.relu), this can be disabled since the scaling will be
# done by the next layer.
dtype: jnp.dtype = jnp.float32
use_bias: bool = True
scale: bool = True
scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones
bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,))
self.beta = self.param("beta", self.scale_init, (self.hidden_size,))
self.weight = self.param("weight", self.scale_init, (self.hidden_size,))
self.bias = self.param("bias", self.scale_init, (self.hidden_size,))
def __call__(self, x):
"""
@@ -146,11 +146,11 @@ class FlaxRobertaLayerNorm(nn.Module):
mul = jax.lax.rsqrt(var + self.epsilon)
if self.scale:
mul = mul * jnp.asarray(self.gamma)
mul = mul * jnp.asarray(self.weight)
y = (x - mean) * mul
if self.bias:
y = y + jnp.asarray(self.beta)
if self.use_bias:
y = y + jnp.asarray(self.bias)
return y
@@ -186,26 +186,21 @@ class FlaxRobertaEmbeddings(nn.Module):
self.config.vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="word_embeddings",
dtype=self.dtype,
)
self.position_embeddings = FlaxRobertaEmbedding(
self.config.max_position_embeddings,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="position_embeddings",
dtype=self.dtype,
)
self.token_type_embeddings = FlaxRobertaEmbedding(
self.config.type_vocab_size,
self.config.hidden_size,
initializer_range=self.config.initializer_range,
name="token_type_embeddings",
dtype=self.dtype,
)
self.layer_norm = FlaxRobertaLayerNorm(
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
)
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
@@ -218,38 +213,119 @@ class FlaxRobertaEmbeddings(nn.Module):
hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings
# Layer Norm
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
class FlaxRobertaSelfAttention(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
if self.config.hidden_size % self.config.num_attention_heads != 0:
raise ValueError(
"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}"
)
self.query = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.key = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
self.value = nn.Dense(
self.config.hidden_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
)
def __call__(self, hidden_states, attention_mask, deterministic=True):
head_dim = self.config.hidden_size // self.config.num_attention_heads
query_states = self.query(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
value_states = self.value(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
key_states = self.key(hidden_states).reshape(
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
)
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.dropout_rate > 0.0:
dropout_rng = self.make_rng("dropout")
attn_output = dot_product_attention(
query_states,
key_states,
value_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.config.attention_probs_dropout_prob,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
return attn_output.reshape(attn_output.shape[:2] + (-1,))
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
class FlaxRobertaSelfOutput(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
dtype=self.dtype,
)
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
class FlaxRobertaAttention(nn.Module):
config: RobertaConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
dtype: jnp.dtype = jnp.float32
def setup(self):
self.self_attention = nn.attention.SelfAttention(
num_heads=self.config.num_attention_heads,
qkv_features=self.config.hidden_size,
dropout_rate=self.config.attention_probs_dropout_prob,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
bias_init=jax.nn.initializers.zeros,
name="self",
dtype=self.dtype,
)
self.layer_norm = FlaxRobertaLayerNorm(
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
)
self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype)
self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic=True):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = self.layer_norm(self_attn_output + hidden_states)
attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic)
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
return hidden_states
@@ -262,7 +338,6 @@ class FlaxRobertaIntermediate(nn.Module):
self.dense = nn.Dense(
self.config.intermediate_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
self.activation = ACT2FN[self.config.hidden_act]
@@ -282,18 +357,15 @@ class FlaxRobertaOutput(nn.Module):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
self.layer_norm = FlaxRobertaLayerNorm(
hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype
)
self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.layer_norm(hidden_states + attention_output)
hidden_states = self.LayerNorm(hidden_states + attention_output)
return hidden_states
@@ -303,9 +375,9 @@ class FlaxRobertaLayer(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.attention = FlaxRobertaAttention(self.config, name="attention", dtype=self.dtype)
self.intermediate = FlaxRobertaIntermediate(self.config, name="intermediate", dtype=self.dtype)
self.output = FlaxRobertaOutput(self.config, name="output", dtype=self.dtype)
self.attention = FlaxRobertaAttention(self.config, dtype=self.dtype)
self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic)
@@ -336,10 +408,10 @@ class FlaxRobertaEncoder(nn.Module):
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = FlaxRobertaLayerCollection(self.config, name="layer", dtype=self.dtype)
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
return self.layers(hidden_states, attention_mask, deterministic=deterministic)
return self.layer(hidden_states, attention_mask, deterministic=deterministic)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
@@ -351,7 +423,6 @@ class FlaxRobertaPooler(nn.Module):
self.dense = nn.Dense(
self.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
name="dense",
dtype=self.dtype,
)
@@ -370,75 +441,6 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
config_class = RobertaConfig
base_model_prefix = "roberta"
@staticmethod
def convert_from_pytorch(pt_state: Dict, config: RobertaConfig) -> Dict:
jax_state = dict(pt_state)
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
for key, tensor in pt_state.items():
# Key parts
key_parts = set(key.split("."))
# Every dense layer has "kernel" parameters instead of "weight"
if "dense.weight" in key:
del jax_state[key]
key = key.replace("weight", "kernel")
jax_state[key] = tensor
# SelfAttention needs also to replace "weight" by "kernel"
if {"query", "key", "value"} & key_parts:
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
if "bias" in key:
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
elif "weight":
del jax_state[key]
key = key.replace("weight", "kernel")
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove one nesting
if "attention.output.dense" in key:
del jax_state[key]
key = key.replace("attention.output.dense", "attention.self.out")
jax_state[key] = tensor
# SelfAttention output is not a separate layer, remove nesting on layer norm
if "attention.output.LayerNorm" in key:
del jax_state[key]
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
jax_state[key] = tensor
# There are some transposed parameters w.r.t their PyTorch counterpart
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
jax_state[key] = tensor.T
# Self Attention output projection needs to be transposed
if "out.kernel" in key:
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
1, 2, 0
)
# Pooler needs to transpose its kernel
if "pooler.dense.kernel" in key:
jax_state[key] = tensor.T
# Handle LayerNorm conversion
if "LayerNorm" in key:
del jax_state[key]
# Replace LayerNorm by layer_norm
new_key = key.replace("LayerNorm", "layer_norm")
if "weight" in key:
new_key = new_key.replace("weight", "gamma")
elif "bias" in key:
new_key = new_key.replace("bias", "beta")
jax_state[new_key] = tensor
return jax_state
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
jnp.zeros(input_shape, dtype="i4"), None, None, None
@@ -523,9 +525,9 @@ class FlaxRobertaModule(nn.Module):
add_pooling_layer: bool = True
def setup(self):
self.embeddings = FlaxRobertaEmbeddings(self.config, name="embeddings", dtype=self.dtype)
self.encoder = FlaxRobertaEncoder(self.config, name="encoder", dtype=self.dtype)
self.pooler = FlaxRobertaPooler(self.config, name="pooler", dtype=self.dtype)
self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):

View File

@@ -115,6 +115,6 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("bert-base-cased")
model = model_class_name.from_pretrained("bert-base-cased", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)

View File

@@ -27,7 +27,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.modeling_flax_utils import convert_state_dict_from_pt
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
@@ -79,8 +79,8 @@ class FlaxModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}

View File

@@ -115,6 +115,6 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("roberta-base")
model = model_class_name.from_pretrained("roberta-base", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)