From 8c9b5fcbaf27cbf1aa781670d598cf74c07b7e88 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 23 Apr 2021 09:53:09 +0200 Subject: [PATCH] [Flax] Big FlaxBert Refactor (#11364) * improve flax * refactor * typos * Update src/transformers/modeling_flax_utils.py * Apply suggestions from code review * Update src/transformers/modeling_flax_utils.py * fix typo * improve error tolerance * typo * correct nasty saving bug * fix from pretrained * correct tree map * add note * correct weight tying --- .../modeling_flax_pytorch_utils.py | 142 +++++++++++++++++- src/transformers/modeling_flax_utils.py | 14 +- src/transformers/modeling_utils.py | 41 ++++- .../models/bert/modeling_flax_bert.py | 140 ++++++----------- .../models/roberta/modeling_flax_roberta.py | 101 +++---------- tests/test_modeling_flax_common.py | 65 +++++++- 6 files changed, 306 insertions(+), 197 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index f1bc431c6c..3ee48316cd 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -12,12 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch - TF 2.0 general utilities.""" +""" PyTorch - Flax general utilities.""" import os +from pickle import UnpicklingError -from flax.core.frozen_dict import unfreeze +import numpy as np + +import jax.numpy as jnp +import transformers +from flax.serialization import from_bytes from flax.traverse_util import flatten_dict, unflatten_dict from .utils import logging @@ -37,7 +42,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa import torch # noqa: F401 except ImportError: logger.error( - "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. 
Please see " + "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see " "https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions." ) raise @@ -57,7 +62,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): # convert pytorch tensor to numpy pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} - random_flax_state_dict = flatten_dict(unfreeze(flax_model.params)) + random_flax_state_dict = flatten_dict(flax_model.params) flax_state_dict = {} remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and ( @@ -80,7 +85,12 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): elif add_base_model_prefix and require_base_model_prefix: pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key - if pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict: + # Correctly rename weight parameters + if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: + pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: + pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) + elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict: pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) pt_tensor = pt_tensor.T elif pt_tuple_key[-1] == "gamma": @@ -89,12 +99,128 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): pt_tuple_key = pt_tuple_key[:-1] + ("bias",) if pt_tuple_key in random_flax_state_dict: - if random_flax_state_dict[pt_tuple_key].shape != pt_tensor.shape: + if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape: raise ValueError( "PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape {random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}." 
) - # add unexpected weight so that warning is thrown - flax_state_dict[pt_tuple_key] = pt_tensor + # also add unexpected weight so that warning is thrown + flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor) return unflatten_dict(flax_state_dict) + + +##################### +# Flax => PyTorch # +##################### + + +def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path): + """Load flax checkpoints in a PyTorch model""" + flax_checkpoint_path = os.path.abspath(flax_checkpoint_path) + logger.info(f"Loading Flax weights from {flax_checkpoint_path}") + + # import correct flax class + flax_cls = getattr(transformers, "Flax" + model.__class__.__name__) + + # load flax weight dict + with open(flax_checkpoint_path, "rb") as state_f: + try: + flax_state_dict = from_bytes(flax_cls, state_f.read()) + except UnpicklingError: + raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ") + + return load_flax_weights_in_pytorch_model(model, flax_state_dict) + + +def load_flax_weights_in_pytorch_model(pt_model, flax_state): + """Load already-deserialized Flax weights (a nested parameter state) into a PyTorch model""" + + try: + import torch # noqa: F401 + except ImportError: + logger.error( + "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see " + "https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions." 
+ ) + raise + + flax_state_dict = flatten_dict(flax_state) + pt_model_dict = pt_model.state_dict() + + remove_base_model_prefix = (pt_model.base_model_prefix in flax_state) and ( + pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()]) + ) + add_base_model_prefix = (pt_model.base_model_prefix not in flax_state) and ( + pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()]) + ) + + # keep track of unexpected & missing keys + unexpected_keys = [] + missing_keys = set(pt_model_dict.keys()) + + for flax_key_tuple, flax_tensor in flax_state_dict.items(): + has_base_model_prefix = flax_key_tuple[0] == pt_model.base_model_prefix + require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict + + # adapt flax_key to prepare for loading from/to base model only + if remove_base_model_prefix and has_base_model_prefix: + flax_key_tuple = flax_key_tuple[1:] + elif add_base_model_prefix and require_base_model_prefix: + flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple + + # rename flax weights to PyTorch format + if flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict: + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + flax_tensor = flax_tensor.T + elif flax_key_tuple[-1] in ["scale", "embedding"]: + flax_key_tuple = flax_key_tuple[:-1] + ("weight",) + + flax_key = ".".join(flax_key_tuple) + + if flax_key in pt_model_dict: + if flax_tensor.shape != pt_model_dict[flax_key].shape: + raise ValueError( + f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected" + f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}." 
+ ) + else: + # add weight to pytorch dict + flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor + pt_model_dict[flax_key] = torch.from_numpy(flax_tensor) + # remove from missing keys + missing_keys.remove(flax_key) + else: + # weight is not expected by PyTorch model + unexpected_keys.append(flax_key) + + pt_model.load_state_dict(pt_model_dict) + + # re-transform missing_keys to list + missing_keys = list(missing_keys) + + if len(unexpected_keys) > 0: + logger.warning( + "Some weights of the Flax model were not used when " + f"initializing the PyTorch model {pt_model.__class__.__name__}: {unexpected_keys}\n" + f"- This IS expected if you are initializing {pt_model.__class__.__name__} from a Flax model trained on another task " + "or with another architecture (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n" + f"- This IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect " + "to be exactly identical (e.g. initializing a BertForSequenceClassification model from a FlaxBertForSequenceClassification model)." + ) + else: + logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model " + f"and are newly initialized: {missing_keys}\n" + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + else: + logger.warning( + f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n" + "If your task is similar to the task the model of the checkpoint was trained on, " + f"you can already use {pt_model.__class__.__name__} for predictions without further training." 
+ ) + + return pt_model diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 7291332ef1..88db0678d9 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team. +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ from typing import Dict, Set, Tuple, Union import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.core.frozen_dict import FrozenDict, unfreeze from flax.serialization import from_bytes, to_bytes from flax.traverse_util import flatten_dict, unflatten_dict from jax.random import PRNGKey @@ -46,7 +46,7 @@ logger = logging.get_logger(__name__) ACT2FN = { - "gelu": nn.gelu, + "gelu": partial(nn.gelu, approximate=False), "relu": nn.relu, "silu": nn.swish, "swish": nn.swish, @@ -129,7 +129,7 @@ class FlaxPreTrainedModel(ABC): "Some parameters are missing. Make sure that `params` include the following " f"parameters {self.required_params - param_keys}" ) - self._params = freeze(params) + self._params = params @classmethod def from_pretrained( @@ -330,6 +330,10 @@ class FlaxPreTrainedModel(ABC): state = from_bytes(cls, state_f.read()) except UnpicklingError: raise EnvironmentError(f"Unable to convert {archive_file} to Flax deserializable object. 
") + # make sure all arrays are stored as jnp.arrays + # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 + state = jax.tree_util.tree_map(jnp.array, state) # if model is base model only use model_prefix key if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state: @@ -337,6 +341,7 @@ class FlaxPreTrainedModel(ABC): # flatten dicts state = flatten_dict(state) + random_state = flatten_dict(unfreeze(model.params)) missing_keys = model.required_params - set(state.keys()) @@ -377,6 +382,7 @@ class FlaxPreTrainedModel(ABC): # set correct parameters model.params = unflatten_dict(state) + return model def save_pretrained(self, save_directory: Union[str, os.PathLike]): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 82a0a99179..d3c2c78f1b 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -30,6 +30,7 @@ from .activations import get_activation from .configuration_utils import PretrainedConfig from .file_utils import ( DUMMY_INPUTS, + FLAX_WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, WEIGHTS_NAME, @@ -875,6 +876,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - A path or url to a model folder containing a `flax checkpoint file` in `.msgpack` format (e.g, + ``./flax_model/`` containing ``flax_model.msgpack``). In this case, ``from_flax`` should be set + to :obj:`True`. - :obj:`None` if you are both providing the configuration and state dictionary (resp. 
model_args (sequence of positional arguments, `optional`): @@ -907,6 +911,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): Load the model weights from a TensorFlow checkpoint save file (see docstring of ``pretrained_model_name_or_path`` argument). + from_flax (:obj:`bool`, `optional`, defaults to :obj:`False`): + Load the model weights from a Flax checkpoint save file (see docstring of + ``pretrained_model_name_or_path`` argument). force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -968,11 +975,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). >>> config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json') >>> model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config) + >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower) + >>> model = BertModel.from_pretrained('bert-base-uncased', from_flax=True) + """ config = kwargs.pop("config", None) state_dict = kwargs.pop("state_dict", None) cache_dir = kwargs.pop("cache_dir", None) from_tf = kwargs.pop("from_tf", False) + from_flax = kwargs.pop("from_flax", False) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) @@ -1023,13 +1034,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): # Load from a TF 2.0 checkpoint in priority if from_tf archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) + elif from_flax and 
os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint in priority if from_flax + archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): # Load from a PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) else: raise EnvironmentError( - f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index']} found in " - f"directory {pretrained_model_name_or_path} or `from_tf` set to False." + f"Error no file named {[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + '.index', FLAX_WEIGHTS_NAME]} found in " + f"directory {pretrained_model_name_or_path} or `from_tf` and `from_flax` set to False." ) elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path @@ -1041,9 +1055,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ) archive_file = pretrained_model_name_or_path + ".index" else: + # set correct filename + if from_tf: + filename = TF2_WEIGHTS_NAME + elif from_flax: + filename = FLAX_WEIGHTS_NAME + else: + filename = WEIGHTS_NAME + archive_file = hf_bucket_url( pretrained_model_name_or_path, - filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME), + filename=filename, revision=revision, mirror=mirror, ) @@ -1090,7 +1112,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): else: model = cls(config, *model_args, **model_kwargs) - if state_dict is None and not from_tf: + if state_dict is None and not (from_tf or from_flax): try: state_dict = torch.load(resolved_archive_file, map_location="cpu") except Exception: @@ -1120,6 +1142,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions." 
) raise + elif from_flax: + try: + from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file) + except ImportError: + logger.error( + "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see " + "https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation instructions." + ) + raise else: # Convert old format to new format if needed from a PyTorch state_dict old_keys = [] diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index c675ddf30d..56a167ee85 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team. +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -95,68 +95,6 @@ BERT_INPUTS_DOCSTRING = r""" """ -class FlaxBertLayerNorm(nn.Module): - """ - Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data. - """ - - hidden_size: int - epsilon: float = 1e-6 - dtype: jnp.dtype = jnp.float32 - use_bias: bool = True - scale: bool = True - scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.weight = self.param("weight", self.scale_init, (self.hidden_size,)) - self.bias = self.param("bias", self.scale_init, (self.hidden_size,)) - - def __call__(self, x): - """ - Applies layer normalization on the input. It normalizes the activations of the layer for each given example in - a batch independently, rather than across a batch like Batch Normalization. i.e. 
applies a transformation that - maintains the mean activation within each example close to 0 and the activation standard deviation close to 1 - - Args: - x: the inputs - - Returns: - Normalized inputs (the same shape as inputs). - """ - mean = jnp.mean(x, axis=-1, keepdims=True) - mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) - var = mean2 - jax.lax.square(mean) - mul = jax.lax.rsqrt(var + self.epsilon) - - if self.scale: - mul = mul * jnp.asarray(self.weight) - y = (x - mean) * mul - - if self.use_bias: - y = y + jnp.asarray(self.bias) - return y - - -class FlaxBertEmbedding(nn.Module): - """ - Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch - use 'weight' - """ - - vocab_size: int - hidden_size: int - initializer_range: float - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range) - self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size)) - - def __call__(self, input_ids): - return jnp.take(self.embeddings, input_ids, axis=0) - - class FlaxBertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -164,35 +102,37 @@ class FlaxBertEmbeddings(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.word_embeddings = FlaxBertEmbedding( + self.word_embeddings = nn.Embed( self.config.vocab_size, self.config.hidden_size, - initializer_range=self.config.initializer_range, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), dtype=self.dtype, ) - self.position_embeddings = FlaxBertEmbedding( + self.position_embeddings = nn.Embed( self.config.max_position_embeddings, self.config.hidden_size, - initializer_range=self.config.initializer_range, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), 
dtype=self.dtype, ) - self.token_type_embeddings = FlaxBertEmbedding( + self.token_type_embeddings = nn.Embed( self.config.type_vocab_size, self.config.hidden_size, - initializer_range=self.config.initializer_range, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), dtype=self.dtype, ) - self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + batch_size, sequence_length = input_ids.shape # Embed - inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4"))) - position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4"))) - token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4"))) + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) # Sum all embeddings - hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + # hidden_states = hidden_states.reshape((batch_size, sequence_length, -1)) # Layer Norm hidden_states = self.LayerNorm(hidden_states) @@ -281,7 +221,7 @@ class FlaxBertSelfOutput(nn.Module): kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), dtype=self.dtype, ) - self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) def __call__(self, hidden_states, input_tensor, deterministic: bool 
= True): @@ -337,7 +277,7 @@ class FlaxBertOutput(nn.Module): dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__(self, hidden_states, attention_output, deterministic: bool = True): hidden_states = self.dense(hidden_states) @@ -372,7 +312,7 @@ class FlaxBertLayerCollection(nn.Module): ] def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - for layer in self.layers: + for i, layer in enumerate(self.layers): hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic) return hidden_states @@ -412,7 +352,7 @@ class FlaxBertPredictionHeadTransform(nn.Module): def setup(self): self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) self.activation = ACT2FN[self.config.hidden_act] - self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__(self, hidden_states): hidden_states = self.dense(hidden_states) @@ -423,14 +363,22 @@ class FlaxBertPredictionHeadTransform(nn.Module): class FlaxBertLMPredictionHead(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros def setup(self): self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype) - self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype) + self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) - def __call__(self, hidden_states): + def __call__(self, hidden_states, shared_embedding=None): hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) + + if shared_embedding is not 
None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + hidden_states += self.bias return hidden_states @@ -441,8 +389,8 @@ class FlaxBertOnlyMLMHead(nn.Module): def setup(self): self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) - def __call__(self, hidden_states): - hidden_states = self.predictions(hidden_states) + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding) return hidden_states @@ -464,8 +412,8 @@ class FlaxBertPreTrainingHeads(nn.Module): self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) self.seq_relationship = nn.Dense(2, dtype=self.dtype) - def __call__(self, hidden_states, pooled_output): - prediction_scores = self.predictions(hidden_states) + def __call__(self, hidden_states, pooled_output, shared_embedding=None): + prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding) seq_relationship_score = self.seq_relationship(pooled_output) return prediction_scores, seq_relationship_score @@ -490,7 +438,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") token_type_ids = jnp.ones_like(input_ids) - position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1]) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) attention_mask = jnp.ones_like(input_ids) params_rng, dropout_rng = jax.random.split(rng) @@ -514,7 +462,7 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): token_type_ids = jnp.ones_like(input_ids) if position_ids is None: - position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1]) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) if attention_mask is None: attention_mask = jnp.ones_like(input_ids) @@ -546,7 
+494,6 @@ class FlaxBertModule(nn.Module): self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): - hidden_states = self.embeddings( input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic ) @@ -582,7 +529,15 @@ class FlaxBertForPreTrainingModule(nn.Module): hidden_states, pooled_output = self.bert( input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic ) - prediction_scores, seq_relationship_score = self.cls(hidden_states, pooled_output) + + if self.config.tie_word_embeddings: + shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + prediction_scores, seq_relationship_score = self.cls( + hidden_states, pooled_output, shared_embedding=shared_embedding + ) return (prediction_scores, seq_relationship_score) @@ -612,8 +567,13 @@ class FlaxBertForMaskedLMModule(nn.Module): # Model hidden_states = self.bert(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic) + if self.config.tie_word_embeddings: + shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + # Compute the prediction scores - logits = self.cls(hidden_states) + logits = self.cls(hidden_states, shared_embedding=shared_embedding) return (logits,) diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index d4311a9286..ef0c46660f 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team. +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Tuple - -import numpy as np +from typing import Tuple import flax.linen as nn import jax @@ -110,70 +108,6 @@ ROBERTA_INPUTS_DOCSTRING = r""" """ -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta -class FlaxRobertaLayerNorm(nn.Module): - """ - Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data. - """ - - hidden_size: int - epsilon: float = 1e-6 - dtype: jnp.dtype = jnp.float32 - use_bias: bool = True - scale: bool = True - scale_init: Callable[..., np.ndarray] = jax.nn.initializers.ones - bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros - - def setup(self): - self.weight = self.param("weight", self.scale_init, (self.hidden_size,)) - self.bias = self.param("bias", self.scale_init, (self.hidden_size,)) - - def __call__(self, x): - """ - Applies layer normalization on the input. It normalizes the activations of the layer for each given example in - a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that - maintains the mean activation within each example close to 0 and the activation standard deviation close to 1 - - Args: - x: the inputs - - Returns: - Normalized inputs (the same shape as inputs). 
- """ - mean = jnp.mean(x, axis=-1, keepdims=True) - mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) - var = mean2 - jax.lax.square(mean) - mul = jax.lax.rsqrt(var + self.epsilon) - - if self.scale: - mul = mul * jnp.asarray(self.weight) - y = (x - mean) * mul - - if self.use_bias: - y = y + jnp.asarray(self.bias) - return y - - -# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta -class FlaxRobertaEmbedding(nn.Module): - """ - Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch - use 'weight' - """ - - vocab_size: int - hidden_size: int - initializer_range: float - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - - def setup(self): - init_fn: Callable[..., np.ndarray] = jax.nn.initializers.normal(stddev=self.initializer_range) - self.embeddings = self.param("weight", init_fn, (self.vocab_size, self.hidden_size)) - - def __call__(self, input_ids): - return jnp.take(self.embeddings, input_ids, axis=0) - - # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta class FlaxRobertaEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -182,35 +116,37 @@ class FlaxRobertaEmbeddings(nn.Module): dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.word_embeddings = FlaxRobertaEmbedding( + self.word_embeddings = nn.Embed( self.config.vocab_size, self.config.hidden_size, - initializer_range=self.config.initializer_range, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), dtype=self.dtype, ) - self.position_embeddings = FlaxRobertaEmbedding( + self.position_embeddings = nn.Embed( self.config.max_position_embeddings, self.config.hidden_size, - initializer_range=self.config.initializer_range, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), dtype=self.dtype, ) - 
self.token_type_embeddings = FlaxRobertaEmbedding( + self.token_type_embeddings = nn.Embed( self.config.type_vocab_size, self.config.hidden_size, - initializer_range=self.config.initializer_range, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), dtype=self.dtype, ) - self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + batch_size, sequence_length = input_ids.shape # Embed - inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4"))) - position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4"))) - token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4"))) + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) # Sum all embeddings - hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + # hidden_states = hidden_states.reshape((batch_size, sequence_length, -1)) # Layer Norm hidden_states = self.LayerNorm(hidden_states) @@ -301,7 +237,7 @@ class FlaxRobertaSelfOutput(nn.Module): kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), dtype=self.dtype, ) - self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) def __call__(self, hidden_states, input_tensor, deterministic: bool = True): 
@@ -360,7 +296,7 @@ class FlaxRobertaOutput(nn.Module): dtype=self.dtype, ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) - self.LayerNorm = FlaxRobertaLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) def __call__(self, hidden_states, attention_output, deterministic: bool = True): hidden_states = self.dense(hidden_states) @@ -397,7 +333,7 @@ class FlaxRobertaLayerCollection(nn.Module): ] def __call__(self, hidden_states, attention_mask, deterministic: bool = True): - for layer in self.layers: + for i, layer in enumerate(self.layers): hidden_states = layer(hidden_states, attention_mask, deterministic=deterministic) return hidden_states @@ -515,7 +451,6 @@ class FlaxRobertaModule(nn.Module): self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype) def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True): - hidden_states = self.embeddings( input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic ) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 462ac4d01d..8d5ca111fd 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -28,7 +28,10 @@ if is_flax_available(): import jax import jax.numpy as jnp - from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, + ) os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 @@ -83,29 +86,32 @@ class FlaxModelTesterMixin: self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") @is_pt_flax_cross_test - def test_equivalence_flax_pytorch(self): + def test_equivalence_pt_to_flax(self): config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: with self.subTest(model_class.__name__): + # prepare inputs prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning pt_model_class = getattr(transformers, pt_model_class_name) - pt_model = pt_model_class(config).eval() + pt_model = pt_model_class(config).eval() fx_model = model_class(config, dtype=jnp.float32) + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_model.params = fx_state - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} - with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() fx_outputs = fx_model(**prepared_inputs_dict) self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3) + self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) @@ -116,7 +122,50 @@ class FlaxModelTesterMixin: len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" ) for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3) + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3) + + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, 
model_class) + pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + pt_model = pt_model_class(config).eval() + fx_model = model_class(config, dtype=jnp.float32) + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).to_tuple() + + fx_outputs = fx_model(**prepared_inputs_dict) + self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + for fx_output, pt_output in zip(fx_outputs, pt_outputs): + self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() + + self.assertEqual( + len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" + ) + for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): + self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3) def test_from_pretrained_save_pretrained(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -134,7 +183,7 @@ class FlaxModelTesterMixin: outputs_loaded = model_loaded(**prepared_inputs_dict) for output_loaded, output in zip(outputs_loaded, outputs): - self.assert_almost_equals(output_loaded, output, 5e-3) + self.assert_almost_equals(output_loaded, output, 1e-3) def test_jit_compilation(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()