From 8780caa388c7b2aa937454ed96bcdd3f097f851d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 30 Mar 2021 12:13:59 +0300 Subject: [PATCH] [WIP][Flax] Add general conversion script (#10809) * save intermediate * finish first version * delete some more * improve import * fix roberta * Update src/transformers/modeling_flax_pytorch_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_flax_pytorch_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * small corrections * apply all comments * fix deterministic * make fix-copies Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../modeling_flax_pytorch_utils.py | 100 +++++++ src/transformers/modeling_flax_utils.py | 47 +-- .../models/bert/modeling_flax_bert.py | 272 +++++++++--------- .../models/roberta/modeling_flax_roberta.py | 240 ++++++++-------- tests/test_modeling_flax_bert.py | 2 +- tests/test_modeling_flax_common.py | 4 +- tests/test_modeling_flax_roberta.py | 2 +- 7 files changed, 370 insertions(+), 297 deletions(-) create mode 100644 src/transformers/modeling_flax_pytorch_utils.py diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py new file mode 100644 index 0000000000..31001b88ee --- /dev/null +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -0,0 +1,100 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch - TF 2.0 general utilities.""" + + +import os + +from flax.core.frozen_dict import unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +##################### +# PyTorch => Flax # +##################### + + +def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False): + """Load pytorch checkpoints in a flax model""" + try: + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see " + "https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions." + ) + raise + + pt_path = os.path.abspath(pytorch_checkpoint_path) + logger.info("Loading PyTorch weights from {}".format(pt_path)) + + pt_state_dict = torch.load(pt_path, map_location="cpu") + logger.info("PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values())} parameters.") + + flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) + + return flax_state_dict + + +def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): + # convert pytorch tensor to numpy + pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + + random_flax_state_dict = flatten_dict(unfreeze(flax_model.params)) + flax_state_dict = {} + + remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and ( + flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()]) + ) + add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and ( + flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()]) + ) + + # Need to change some parameters name to match Flax names so that we don't have to fork any layer + for pt_key, pt_tensor in pt_state_dict.items(): + + pt_tuple_key = tuple(pt_key.split(".")) + + has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix + require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict + + if remove_base_model_prefix and has_base_model_prefix: + pt_tuple_key = pt_tuple_key[1:] + elif add_base_model_prefix and require_base_model_prefix: + pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key + + if pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict: + pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + pt_tensor = pt_tensor.T + elif pt_tuple_key[-1] == "gamma": + pt_tuple_key = pt_tuple_key[:-1] + ("weight",) + elif pt_tuple_key[-1] == "beta": + pt_tuple_key = pt_tuple_key[:-1] + ("bias",) + + if pt_tuple_key in random_flax_state_dict: + if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape: + raise ValueError( + "PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}." + ) + + # add unexpected weight so that warning is thrown + flax_state_dict[pt_tuple_key] = pt_tensor + + return unflatten_dict(flax_state_dict) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 8b245f6546..55d7e37143 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -14,7 +14,7 @@ # limitations under the License. import os -from abc import ABC, abstractmethod +from abc import ABC from functools import partial from pickle import UnpicklingError from typing import Dict, Set, Tuple, Union @@ -29,6 +29,7 @@ from jax.random import PRNGKey from .configuration_utils import PretrainedConfig from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_offline_mode, is_remote_url +from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict from .utils import logging @@ -121,11 +122,6 @@ class FlaxPreTrainedModel(ABC): ) self._params = freeze(params) - @staticmethod - @abstractmethod - def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict: - raise NotImplementedError() - @classmethod def from_pretrained( cls, @@ -307,25 +303,18 @@ class FlaxPreTrainedModel(ABC): else: resolved_archive_file = None - # Instantiate model. - with open(resolved_archive_file, "rb") as state_f: - try: - if from_pt: - import torch - - state = torch.load(state_f) - - state = convert_state_dict_from_pt(cls, state, config) - else: - state = from_bytes(cls, state_f.read()) - except UnpicklingError: - raise EnvironmentError( - f"Unable to convert pytorch model {archive_file} to Flax deserializable object. " - ) - # init random models model = cls(config, *model_args, **model_kwargs) + if from_pt: + state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file) + else: + with open(resolved_archive_file, "rb") as state_f: + try: + state = from_bytes(cls, state_f.read()) + except UnpicklingError: + raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. ") + # if model is base model only use model_prefix key if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state: state = state[cls.base_model_prefix] @@ -341,6 +330,10 @@ class FlaxPreTrainedModel(ABC): for missing_key in missing_keys: state[missing_key] = random_state[missing_key] + # remove unexpected keys to not be saved again + for unexpected_key in unexpected_keys: + del state[unexpected_key] + if len(unexpected_keys) > 0: logger.warning( f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " @@ -393,13 +386,3 @@ class FlaxPreTrainedModel(ABC): with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f: model_bytes = to_bytes(self.params) f.write(model_bytes) - - -def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig): - """ - Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict - """ - state = {k: v.numpy() for k, v in state.items()} - state = model_class.convert_from_pytorch(state, config) - state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()}) - return state diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 97a219f12c..8a37721d7e 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Tuple +from typing import Callable, Tuple import numpy as np @@ -21,6 +21,8 @@ import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict +from flax.linen import dot_product_attention +from jax import lax from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward @@ -99,17 +101,15 @@ class FlaxBertLayerNorm(nn.Module): hidden_size: int epsilon: float = 1e-6 - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - bias: bool = True # If True, bias (beta) is added. - scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear - # (also e.g. nn.relu), this can be disabled since the scaling will be - # done by the next layer. + dtype: jnp.dtype = jnp.float32 + use_bias: bool = True + scale: bool = True scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros def setup(self): - self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,)) - self.beta = self.param("beta", self.scale_init, (self.hidden_size,)) + self.weight = self.param("weight", self.scale_init, (self.hidden_size,)) + self.bias = self.param("bias", self.scale_init, (self.hidden_size,)) def __call__(self, x): """ @@ -129,11 +129,11 @@ class FlaxBertLayerNorm(nn.Module): mul = jax.lax.rsqrt(var + self.epsilon) if self.scale: - mul = mul * jnp.asarray(self.gamma) + mul = mul * jnp.asarray(self.weight) y = (x - mean) * mul - if self.bias: - y = y + jnp.asarray(self.beta) + if self.use_bias: + y = y + jnp.asarray(self.bias) return y @@ -167,24 +167,21 @@ class FlaxBertEmbeddings(nn.Module): self.config.vocab_size, self.config.hidden_size, initializer_range=self.config.initializer_range, - name="word_embeddings", dtype=self.dtype, ) self.position_embeddings = FlaxBertEmbedding( self.config.max_position_embeddings, self.config.hidden_size, initializer_range=self.config.initializer_range, - name="position_embeddings", dtype=self.dtype, ) self.token_type_embeddings = FlaxBertEmbedding( self.config.type_vocab_size, self.config.hidden_size, initializer_range=self.config.initializer_range, - name="token_type_embeddings", dtype=self.dtype, ) - self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype) + self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): @@ -197,35 +194,116 @@ class FlaxBertEmbeddings(nn.Module): hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings # Layer Norm - hidden_states = self.layer_norm(hidden_states) + hidden_states = self.LayerNorm(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) return hidden_states +class FlaxBertSelfAttention(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + + def __call__(self, hidden_states, attention_mask, deterministic=True): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e10).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout_rate > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_output = dot_product_attention( + query_states, + key_states, + value_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + return attn_output.reshape(attn_output.shape[:2] + (-1,)) + + +class FlaxBertSelfOutput(nn.Module): + config: BertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + class FlaxBertAttention(nn.Module): config: BertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation + dtype: jnp.dtype = jnp.float32 def setup(self): - self.self_attention = nn.attention.SelfAttention( - num_heads=self.config.num_attention_heads, - qkv_features=self.config.hidden_size, - dropout_rate=self.config.attention_probs_dropout_prob, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), - bias_init=jax.nn.initializers.zeros, - name="self", - dtype=self.dtype, - ) - self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype) + self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype) + self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype) def __call__(self, hidden_states, attention_mask, deterministic=True): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic) - - hidden_states = self.layer_norm(self_attn_output + hidden_states) + attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic) + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) return hidden_states @@ -237,7 +315,6 @@ class FlaxBertIntermediate(nn.Module): self.dense = nn.Dense( self.config.intermediate_size, kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), - name="dense", dtype=self.dtype, ) self.activation = ACT2FN[self.config.hidden_act] @@ -256,16 +333,15 @@ class FlaxBertOutput(nn.Module): self.dense = nn.Dense( self.config.hidden_size, kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), - name="dense", dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype) + self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype) def __call__(self, hidden_states, attention_output, deterministic: bool = True): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.layer_norm(hidden_states + attention_output) + hidden_states = self.LayerNorm(hidden_states + attention_output) return hidden_states @@ -274,9 +350,9 @@ class FlaxBertLayer(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.attention = FlaxBertAttention(self.config, name="attention", dtype=self.dtype) - self.intermediate = FlaxBertIntermediate(self.config, name="intermediate", dtype=self.dtype) - self.output = FlaxBertOutput(self.config, name="output", dtype=self.dtype) + self.attention = FlaxBertAttention(self.config, dtype=self.dtype) + self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype) + self.output = FlaxBertOutput(self.config, dtype=self.dtype) def __call__(self, hidden_states, attention_mask, deterministic: bool = True): attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic) @@ -305,10 +381,10 @@ class FlaxBertEncoder(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.layers = FlaxBertLayerCollection(self.config, name="layer", dtype=self.dtype) + self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype) def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - return self.layers(hidden_states, attention_mask, deterministic=deterministic) + return self.layer(hidden_states, attention_mask, deterministic=deterministic) class FlaxBertPooler(nn.Module): @@ -319,7 +395,6 @@ class FlaxBertPooler(nn.Module): self.dense = nn.Dense( self.config.hidden_size, kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), - name="dense", dtype=self.dtype, ) @@ -334,14 +409,14 @@ class FlaxBertPredictionHeadTransform(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - self.dense = nn.Dense(self.config.hidden_size, name="dense", dtype=self.dtype) + self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) self.activation = ACT2FN[self.config.hidden_act] - self.layer_norm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype) + self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype) def __call__(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.activation(hidden_states) - return self.layer_norm(hidden_states) + return self.LayerNorm(hidden_states) class FlaxBertLMPredictionHead(nn.Module): @@ -349,14 +424,10 @@ class FlaxBertLMPredictionHead(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - self.transform = FlaxBertPredictionHeadTransform(self.config, name="transform", dtype=self.dtype) - self.decoder = nn.Dense(self.config.vocab_size, name="decoder", dtype=self.dtype) + self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype) + self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype) def __call__(self, hidden_states): - # TODO: The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - # Need a link between the two variables so that the bias is correctly - # resized with `resize_token_embeddings` hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states @@ -367,10 +438,10 @@ class FlaxBertOnlyMLMHead(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - self.mlm_head = FlaxBertLMPredictionHead(self.config, name="predictions", dtype=self.dtype) + self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) def __call__(self, hidden_states): - hidden_states = self.mlm_head(hidden_states) + hidden_states = self.predictions(hidden_states) return hidden_states @@ -405,85 +476,6 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"] - @staticmethod - def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict: - jax_state = dict(pt_state) - - # Need to change some parameters name to match Flax names so that we don't have to fork any layer - for key, tensor in pt_state.items(): - # Key parts - key_parts = set(key.split(".")) - - # Every dense layer has "kernel" parameters instead of "weight" - if "dense.weight" in key: - del jax_state[key] - key = key.replace("weight", "kernel") - jax_state[key] = tensor - - if "decoder.weight" in key: - del jax_state[key] - key = key.replace("weight", "kernel") - jax_state[key] = tensor.T - - # SelfAttention needs also to replace "weight" by "kernel" - if {"query", "key", "value"} & key_parts: - - # Flax SelfAttention decomposes the heads (num_head, size // num_heads) - if "bias" in key: - jax_state[key] = tensor.reshape((config.num_attention_heads, -1)) - elif "weight": - del jax_state[key] - key = key.replace("weight", "kernel") - tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1)) - jax_state[key] = tensor - - # SelfAttention output is not a separate layer, remove one nesting - if "attention.output.dense" in key: - del jax_state[key] - key = key.replace("attention.output.dense", "attention.self.out") - jax_state[key] = tensor - - # SelfAttention output is not a separate layer, remove nesting on layer norm - if "attention.output.LayerNorm" in key: - del jax_state[key] - key = key.replace("attention.output.LayerNorm", "attention.LayerNorm") - jax_state[key] = tensor - - # There are some transposed parameters w.r.t their PyTorch counterpart - if "intermediate.dense.kernel" in key or "output.dense.kernel" in key or "transform.dense.kernel" in key: - jax_state[key] = tensor.T - - # Self Attention output projection needs to be transposed - if "out.kernel" in key: - jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose( - 1, 2, 0 - ) - - # Pooler needs to transpose its kernel - if "pooler.dense.kernel" in key: - jax_state[key] = tensor.T - - # Hack to correctly load some pytorch models - if "predictions.bias" in key: - del jax_state[key] - jax_state[".".join(key.split(".")[:2]) + ".decoder.bias"] = tensor - - # Handle LayerNorm conversion - if "LayerNorm" in key: - del jax_state[key] - - # Replace LayerNorm by layer_norm - new_key = key.replace("LayerNorm", "layer_norm") - - if "weight" in key: - new_key = new_key.replace("weight", "gamma") - elif "bias" in key: - new_key = new_key.replace("bias", "beta") - - jax_state[new_key] = tensor - - return jax_state - @add_start_docstrings( "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", @@ -541,9 +533,9 @@ class FlaxBertModule(nn.Module): add_pooling_layer: bool = True def setup(self): - self.embeddings = FlaxBertEmbeddings(self.config, name="embeddings", dtype=self.dtype) - self.encoder = FlaxBertEncoder(self.config, name="encoder", dtype=self.dtype) - self.pooler = FlaxBertPooler(self.config, name="pooler", dtype=self.dtype) + self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype) + self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): @@ -602,15 +594,13 @@ class FlaxBertForMaskedLMModule(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - self.encoder = FlaxBertModule( + self.bert = FlaxBertModule( config=self.config, add_pooling_layer=False, - name="bert", ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.mlm_head = FlaxBertOnlyMLMHead( + self.cls = FlaxBertOnlyMLMHead( config=self.config, - name="cls", dtype=self.dtype, ) @@ -618,12 +608,10 @@ class FlaxBertForMaskedLMModule(nn.Module): self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True ): # Model - hidden_states = self.encoder( - input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic - ) + hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) # Compute the prediction scores hidden_states = self.dropout(hidden_states, deterministic=deterministic) - logits = self.mlm_head(hidden_states) + logits = self.cls(hidden_states) return (logits,) diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index eeff923fcf..25d8a247cc 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Tuple +from typing import Callable, Tuple import numpy as np @@ -20,6 +20,8 @@ import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict +from flax.linen import dot_product_attention +from jax import lax from jax.random import PRNGKey from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward @@ -116,17 +118,15 @@ class FlaxRobertaLayerNorm(nn.Module): hidden_size: int epsilon: float = 1e-6 - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - bias: bool = True # If True, bias (beta) is added. - scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear - # (also e.g. nn.relu), this can be disabled since the scaling will be - # done by the next layer. + dtype: jnp.dtype = jnp.float32 + use_bias: bool = True + scale: bool = True scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros def setup(self): - self.gamma = self.param("gamma", self.scale_init, (self.hidden_size,)) - self.beta = self.param("beta", self.scale_init, (self.hidden_size,)) + self.weight = self.param("weight", self.scale_init, (self.hidden_size,)) + self.bias = self.param("bias", self.scale_init, (self.hidden_size,)) def __call__(self, x): """ @@ -146,11 +146,11 @@ class FlaxRobertaLayerNorm(nn.Module): mul = jax.lax.rsqrt(var + self.epsilon) if self.scale: - mul = mul * jnp.asarray(self.gamma) + mul = mul * jnp.asarray(self.weight) y = (x - mean) * mul - if self.bias: - y = y + jnp.asarray(self.beta) + if self.use_bias: + y = y + jnp.asarray(self.bias) return y @@ -186,26 +186,21 @@ class FlaxRobertaEmbeddings(nn.Module): self.config.vocab_size, self.config.hidden_size, initializer_range=self.config.initializer_range, - name="word_embeddings", dtype=self.dtype, ) self.position_embeddings = FlaxRobertaEmbedding( self.config.max_position_embeddings, self.config.hidden_size, initializer_range=self.config.initializer_range, - name="position_embeddings", dtype=self.dtype, ) self.token_type_embeddings = FlaxRobertaEmbedding( self.config.type_vocab_size, self.config.hidden_size, initializer_range=self.config.initializer_range, - name="token_type_embeddings", dtype=self.dtype, ) - self.layer_norm = FlaxRobertaLayerNorm( - hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype - ) + self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): @@ -218,38 +213,119 @@ class FlaxRobertaEmbeddings(nn.Module): hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings # Layer Norm - hidden_states = self.layer_norm(hidden_states) + hidden_states = self.LayerNorm(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) return hidden_states +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta +class FlaxRobertaSelfAttention(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + ) + + def __call__(self, hidden_states, attention_mask, deterministic=True): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e10).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout_rate > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_output = dot_product_attention( + query_states, + key_states, + value_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + return attn_output.reshape(attn_output.shape[:2] + (-1,)) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta +class FlaxRobertaSelfOutput(nn.Module): + config: RobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta class FlaxRobertaAttention(nn.Module): config: RobertaConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation + dtype: jnp.dtype = jnp.float32 def setup(self): - self.self_attention = nn.attention.SelfAttention( - num_heads=self.config.num_attention_heads, - qkv_features=self.config.hidden_size, - dropout_rate=self.config.attention_probs_dropout_prob, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), - bias_init=jax.nn.initializers.zeros, - name="self", - dtype=self.dtype, - ) - self.layer_norm = FlaxRobertaLayerNorm( - hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype - ) + self.self = FlaxRobertaSelfAttention(self.config, dtype=self.dtype) + self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype) def __call__(self, hidden_states, attention_mask, deterministic=True): # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) - attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) - self_attn_output = self.self_attention(hidden_states, attention_mask, deterministic=deterministic) - - hidden_states = self.layer_norm(self_attn_output + hidden_states) + attn_output = self.self(hidden_states, attention_mask, deterministic=deterministic) + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) return hidden_states @@ -262,7 +338,6 @@ class FlaxRobertaIntermediate(nn.Module): self.dense = nn.Dense( self.config.intermediate_size, kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), - name="dense", dtype=self.dtype, ) self.activation = ACT2FN[self.config.hidden_act] @@ -282,18 +357,15 @@ class FlaxRobertaOutput(nn.Module): self.dense = nn.Dense( self.config.hidden_size, kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), - name="dense", dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.layer_norm = FlaxRobertaLayerNorm( - hidden_size=self.config.hidden_size, name="layer_norm", dtype=self.dtype - ) + self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype) def __call__(self, hidden_states, attention_output, deterministic: bool = True): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states, deterministic=deterministic) - hidden_states = self.layer_norm(hidden_states + attention_output) + hidden_states = self.LayerNorm(hidden_states + attention_output) return hidden_states @@ -303,9 +375,9 @@ class FlaxRobertaLayer(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.attention = FlaxRobertaAttention(self.config, name="attention", dtype=self.dtype) - self.intermediate = FlaxRobertaIntermediate(self.config, name="intermediate", dtype=self.dtype) - self.output = FlaxRobertaOutput(self.config, name="output", dtype=self.dtype) + self.attention = FlaxRobertaAttention(self.config, dtype=self.dtype) + self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype) + self.output = FlaxRobertaOutput(self.config, dtype=self.dtype) def __call__(self, hidden_states, attention_mask, deterministic: bool = True): attention_output = self.attention(hidden_states, attention_mask, deterministic=deterministic) @@ -336,10 +408,10 @@ class FlaxRobertaEncoder(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.layers = FlaxRobertaLayerCollection(self.config, name="layer", dtype=self.dtype) + self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype) def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - return self.layers(hidden_states, attention_mask, deterministic=deterministic) + return self.layer(hidden_states, attention_mask, deterministic=deterministic) # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta @@ -351,7 +423,6 @@ class FlaxRobertaPooler(nn.Module): self.dense = nn.Dense( self.config.hidden_size, kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), - name="dense", dtype=self.dtype, ) @@ -370,75 +441,6 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): config_class = RobertaConfig base_model_prefix = "roberta" - @staticmethod - def convert_from_pytorch(pt_state: Dict, config: RobertaConfig) -> Dict: - jax_state = dict(pt_state) - - # Need to change some parameters name to match Flax names so that we don't have to fork any layer - for key, tensor in pt_state.items(): - # Key parts - key_parts = set(key.split(".")) - - # Every dense layer has "kernel" parameters instead of "weight" - if "dense.weight" in key: - del jax_state[key] - key = key.replace("weight", "kernel") - jax_state[key] = tensor - - # SelfAttention needs also to replace "weight" by "kernel" - if {"query", "key", "value"} & key_parts: - - # Flax SelfAttention decomposes the heads (num_head, size // num_heads) - if "bias" in key: - jax_state[key] = tensor.reshape((config.num_attention_heads, -1)) - elif "weight": - del jax_state[key] - key = key.replace("weight", "kernel") - tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1)) - jax_state[key] = tensor - - # SelfAttention output is not a separate layer, remove one nesting - if "attention.output.dense" in key: - del jax_state[key] - key = key.replace("attention.output.dense", "attention.self.out") - jax_state[key] = tensor - - # SelfAttention output is not a separate layer, remove nesting on layer norm - if "attention.output.LayerNorm" in key: - del jax_state[key] - key = key.replace("attention.output.LayerNorm", "attention.LayerNorm") - jax_state[key] = tensor - - # There are some transposed parameters w.r.t their PyTorch counterpart - if "intermediate.dense.kernel" in key or "output.dense.kernel" in key: - jax_state[key] = tensor.T - - # Self Attention output projection needs to be transposed - if "out.kernel" in key: - jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose( - 1, 2, 0 - ) - - # Pooler needs to transpose its kernel - if "pooler.dense.kernel" in key: - jax_state[key] = tensor.T - - # Handle LayerNorm conversion - if "LayerNorm" in key: - del jax_state[key] - - # Replace LayerNorm by layer_norm - new_key = key.replace("LayerNorm", "layer_norm") - - if "weight" in key: - new_key = new_key.replace("weight", "gamma") - elif "bias" in key: - new_key = new_key.replace("bias", "beta") - - jax_state[new_key] = tensor - - return jax_state - def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs( jnp.zeros(input_shape, dtype="i4"), None, None, None @@ -523,9 +525,9 @@ class FlaxRobertaModule(nn.Module): add_pooling_layer: bool = True def setup(self): - self.embeddings = FlaxRobertaEmbeddings(self.config, name="embeddings", dtype=self.dtype) - self.encoder = FlaxRobertaEncoder(self.config, name="encoder", dtype=self.dtype) - self.pooler = FlaxRobertaPooler(self.config, name="pooler", dtype=self.dtype) + self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype) + self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype) def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): diff --git a/tests/test_modeling_flax_bert.py b/tests/test_modeling_flax_bert.py index e201b8db82..c9946021f2 100644 --- a/tests/test_modeling_flax_bert.py +++ b/tests/test_modeling_flax_bert.py @@ -115,6 +115,6 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: - model = model_class_name.from_pretrained("bert-base-cased") + model = model_class_name.from_pretrained("bert-base-cased", from_pt=True) outputs = model(np.ones((1, 1))) self.assertIsNotNone(outputs) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 0b517a5f43..afa436a9cf 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -27,7 +27,7 @@ if is_flax_available(): import jax import jax.numpy as jnp - from transformers.modeling_flax_utils import convert_state_dict_from_pt + from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 @@ -79,8 +79,8 @@ class FlaxModelTesterMixin: pt_model_class = getattr(transformers, pt_model_class_name) pt_model = pt_model_class(config).eval() - fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config) fx_model = model_class(config, dtype=jnp.float32) + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_model.params = fx_state pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()} diff --git a/tests/test_modeling_flax_roberta.py b/tests/test_modeling_flax_roberta.py index 318d934ce3..3c75f17d9d 100644 --- a/tests/test_modeling_flax_roberta.py +++ b/tests/test_modeling_flax_roberta.py @@ -115,6 +115,6 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase): @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: - model = model_class_name.from_pretrained("roberta-base") + model = model_class_name.from_pretrained("roberta-base", from_pt=True) outputs = model(np.ones((1, 1))) self.assertIsNotNone(outputs)