diff --git a/.circleci/config.yml b/.circleci/config.yml index 17a5d80f60..80d0366ea4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -139,6 +139,31 @@ jobs: - store_artifacts: path: ~/transformers/output.txt destination: test_output.txt + run_tests_flax: + working_directory: ~/transformers + docker: + - image: circleci/python:3.7 + environment: + OMP_NUM_THREADS: 1 + resource_class: xlarge + parallelism: 1 + steps: + - checkout + - restore_cache: + keys: + - v0.3-flax-{{ checksum "setup.py" }} + - v0.3-{{ checksum "setup.py" }} + - run: pip install --upgrade pip + - run: pip install git+https://github.com/huggingface/datasets + - run: sudo pip install .[flax,sklearn,torch,testing] + - save_cache: + key: v0.3-flax-{{ checksum "setup.py" }} + paths: + - '~/.cache/pip' + - run: python -m pytest -n 8 --dist=loadfile -rA -s ./tests/ | tee output.txt + - store_artifacts: + path: ~/transformers/output.txt + destination: test_output.txt run_tests_custom_tokenizers: working_directory: ~/transformers docker: @@ -305,6 +330,7 @@ workflows: - run_tests_torch_and_tf - run_tests_torch - run_tests_tf + - run_tests_flax - build_doc - deploy_doc: *workflow_filters tpu_testing_jobs: diff --git a/setup.py b/setup.py index 5a6ef149d9..e197dcf9fc 100644 --- a/setup.py +++ b/setup.py @@ -87,6 +87,7 @@ extras["tf-cpu"] = [ # "keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx", ] extras["torch"] = ["torch>=1.0"] +extras["flax"] = ["jaxlib==0.1.55", "jax>=0.2.0", "flax==0.2.2"] extras["onnxruntime"] = ["onnxruntime>=1.4.0", "onnxruntime-tools>=1.4.2"] extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"] diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index af91c2e656..6f6c409efe 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -103,6 +103,7 @@ from .file_utils import ( is_apex_available, is_datasets_available, is_faiss_available, + 
is_flax_available, is_psutil_available, is_py3nvml_available, is_sentencepiece_available, @@ -817,6 +818,10 @@ else: from .utils.dummy_tf_objects import * +if is_flax_available(): + from .modeling_flax_bert import FlaxBertModel + from .modeling_flax_roberta import FlaxRobertaModel + if not is_tf_available() and not is_torch_available(): logger.warning( "Neither PyTorch nor TensorFlow >= 2.0 have been found." diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 5832bb0993..b4ca47077f 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -34,10 +34,13 @@ from .utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + try: USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() - if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"): + if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: import torch _torch_available = True # pylint: disable=invalid-name @@ -52,7 +55,7 @@ try: USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() - if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"): + if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: import tensorflow as tf assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 @@ -65,6 +68,22 @@ except (ImportError, AssertionError): _tf_available = False # pylint: disable=invalid-name +try: + USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() + + if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + import flax + import jax + + logger.info("JAX version {}, Flax: available".format(jax.__version__)) + logger.info("Flax available: {}".format(flax)) + _flax_available = True + else: + _flax_available = False 
+except ImportError: + _flax_available = False # pylint: disable=invalid-name + + try: import datasets # noqa: F401 @@ -213,6 +232,10 @@ def is_tf_available(): return _tf_available +def is_flax_available(): + return _flax_available + + def is_torch_tpu_available(): return _torch_tpu_available diff --git a/src/transformers/modeling_flax_auto.py b/src/transformers/modeling_flax_auto.py new file mode 100644 index 0000000000..22b56e25c0 --- /dev/null +++ b/src/transformers/modeling_flax_auto.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Auto Model class. 
""" + + +from collections import OrderedDict + +from .configuration_auto import AutoConfig, BertConfig, RobertaConfig +from .configuration_utils import PretrainedConfig +from .modeling_flax_bert import FlaxBertModel +from .modeling_flax_roberta import FlaxRobertaModel +from .utils import logging + + +logger = logging.get_logger(__name__) + + +ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict( + (key, value) + for pretrained_map in [ + FlaxBertModel.pretrained_model_archive_map, + FlaxRobertaModel.pretrained_model_archive_map, + ] + for key, value, in pretrained_map.items() +) + +MODEL_MAPPING = OrderedDict( + [ + (RobertaConfig, FlaxRobertaModel), + (BertConfig, FlaxBertModel), + ] +) + + +class FlaxAutoModel(object): + r""" + :class:`~transformers.FlaxAutoModel` is a generic model class + that will be instantiated as one of the base model classes of the library + when created with the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` + or the `FlaxAutoModel.from_config(config)` class methods. + + This class cannot be instantiated using `__init__()` (throws an error). + """ + + def __init__(self): + raise EnvironmentError( + "FlaxAutoModel is designed to be instantiated " + "using the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or " + "`FlaxAutoModel.from_config(config)` methods." + ) + + @classmethod + def from_config(cls, config): + r"""Instantiates one of the base model classes of the library + from a configuration. + + Args: + config (:class:`~transformers.PretrainedConfig`): + The model class to instantiate is selected based on the configuration class: + + - isInstance of `roberta` configuration class: :class:`~transformers.FlaxRobertaModel` (RoBERTa model) + - isInstance of `bert` configuration class: :class:`~transformers.FlaxBertModel` (Bert model) + Examples: + + config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. + model = FlaxAutoModel.from_config(config) # E.g. 
model was saved using `save_pretrained('./test/saved_model/')` + """ + for config_class, model_class in MODEL_MAPPING.items(): + if isinstance(config, config_class): + return model_class(config) + raise ValueError( + f"Unrecognized configuration class {config.__class__} " + f"for this kind of FlaxAutoModel: {cls.__name__}.\n" + f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}." + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r"""Instantiates one of the base model classes of the library + from a pre-trained model configuration. + + The `from_pretrained()` method takes care of returning the correct model class instance + based on the `model_type` property of the config object, or when it's missing, + falling back to using pattern matching on the `pretrained_model_name_or_path` string. + + The base model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `roberta`: :class:`~transformers.FlaxRobertaModel` (RoBERTa model) + - contains `bert`: :class:`~transformers.FlaxBertModel` (Bert model) + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with `model.train()` + + Args: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``. + - a path to a `directory` containing model weights saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). 
In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args: (`optional`) Sequence of positional arguments: + All remaining positional arguments will be passed to the underlying model's ``__init__`` method + + config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and is reloaded by supplying the save directory. + - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and :func:`~transformers.FlaxPreTrainedModel.from_pretrained` is not a simpler option. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exist. 
+ + resume_download: (`optional`) boolean, default False: + Do not delete incompletely received file. Attempt to resume the download if such a file exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages. + + kwargs: (`optional`) Remaining dictionary of keyword arguments: + These arguments will be passed to the configuration and the model. + + Examples:: + + model = FlaxAutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. + model = FlaxAutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + assert model.config.output_attention == True + + """ + config = kwargs.pop("config", None) + if not isinstance(config, PretrainedConfig): + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + for config_class, model_class in MODEL_MAPPING.items(): + if isinstance(config, config_class): + return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) + raise ValueError( + f"Unrecognized configuration class {config.__class__} " + f"for this kind of FlaxAutoModel: {cls.__name__}.\n" + f"Model type should be one of {', '.join(c.__name__ for c in MODEL_MAPPING.keys())}" + ) diff --git a/src/transformers/modeling_flax_bert.py b/src/transformers/modeling_flax_bert.py new file mode 100644 index 0000000000..2ca6e0935d --- /dev/null +++ b/src/transformers/modeling_flax_bert.py @@ -0,0 +1,438 @@ +# coding=utf-8 +# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict + +import numpy as np + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.linen import compact + +from .configuration_bert import BertConfig +from .file_utils import add_start_docstrings +from .modeling_flax_utils import FlaxPreTrainedModel, gelu +from .utils import logging + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "BertConfig" +_TOKENIZER_FOR_DOC = "BertTokenizer" + + +BERT_START_DOCSTRING = r""" + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. 
+ + Indices can be obtained using :class:`~transformers.BertTokenizer`. + See :meth:`transformers.PreTrainedTokenizer.encode` and + :meth:`transformers.PreTrainedTokenizer.__call__` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. 
+ output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +class FlaxBertLayerNorm(nn.Module): + """Layer normalization (https://arxiv.org/abs/1607.06450). + Operates on the last axis of the input data. + """ + + epsilon: float = 1e-6 + dtype: jnp.dtype = jnp.float32 + bias: bool = True + scale: bool = True + bias_init: jnp.ndarray = nn.initializers.zeros + scale_init: jnp.ndarray = nn.initializers.ones + + @compact + def __call__(self, x): + """Applies layer normalization on the input. + It normalizes the activations of the layer for each given example in a + batch independently, rather than across a batch like Batch Normalization. + i.e. applies a transformation that maintains the mean activation within + each example close to 0 and the activation standard deviation close to 1. + Args: + x: the inputs + epsilon: A small float added to variance to avoid dividing by zero. + dtype: the dtype of the computation (default: float32). + bias: If True, bias (beta) is added. + scale: If True, multiply by scale (gamma). When the next layer is linear + (also e.g. nn.relu), this can be disabled since the scaling will be done + by the next layer. + bias_init: Initializer for bias, by default, zero. + scale_init: Initializer for scale, by default, one. + Returns: + Normalized inputs (the same shape as inputs). 
+ """ + features = x.shape[-1] + mean = jnp.mean(x, axis=-1, keepdims=True) + mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) + var = mean2 - jax.lax.square(mean) + mul = jax.lax.rsqrt(var + self.epsilon) + if self.scale: + mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype) + y = (x - mean) * mul + if self.bias: + y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype) + return y + + +class FlaxBertEmbedding(nn.Module): + """ + Specify a new class for doing the embedding stuff + as Flax's one use 'embedding' for the parameter name + and PyTorch use 'weight' + """ + + vocab_size: int + hidden_size: int + emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1) + + @compact + def __call__(self, inputs): + embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size)) + return jnp.take(embedding, inputs, axis=0) + + +class FlaxBertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + vocab_size: int + hidden_size: int + type_vocab_size: int + max_length: int + + @compact + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): + + # Embed + w_emb = FlaxBertEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")( + jnp.atleast_2d(input_ids.astype("i4")) + ) + p_emb = FlaxBertEmbedding(self.max_length, self.hidden_size, name="position_embeddings")( + jnp.atleast_2d(position_ids.astype("i4")) + ) + t_emb = FlaxBertEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")( + jnp.atleast_2d(token_type_ids.astype("i4")) + ) + + # Sum all embeddings + summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb + + # Layer Norm + layer_norm = FlaxBertLayerNorm(name="layer_norm")(summed_emb) + + return layer_norm + + +class FlaxBertAttention(nn.Module): + num_heads: int + head_size: int + + @compact + def __call__(self, hidden_state, attention_mask): + 
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")( + hidden_state, attention_mask + ) + + layer_norm = FlaxBertLayerNorm(name="layer_norm")(self_att + hidden_state) + return layer_norm + + +class FlaxBertIntermediate(nn.Module): + output_size: int + + @compact + def __call__(self, hidden_state): + # TODO: Add ACT2FN reference to change activation function + dense = nn.Dense(features=self.output_size, name="dense")(hidden_state) + return gelu(dense) + + +class FlaxBertOutput(nn.Module): + @compact + def __call__(self, intermediate_output, attention_output): + hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output) + hidden_state = FlaxBertLayerNorm(name="layer_norm")(hidden_state + attention_output) + return hidden_state + + +class FlaxBertLayer(nn.Module): + num_heads: int + head_size: int + intermediate_size: int + + @compact + def __call__(self, hidden_state, attention_mask): + attention = FlaxBertAttention(self.num_heads, self.head_size, name="attention")(hidden_state, attention_mask) + intermediate = FlaxBertIntermediate(self.intermediate_size, name="intermediate")(attention) + output = FlaxBertOutput(name="output")(intermediate, attention) + + return output + + +class FlaxBertLayerCollection(nn.Module): + """ + Stores N BertLayer(s) + """ + + num_layers: int + num_heads: int + head_size: int + intermediate_size: int + + @compact + def __call__(self, inputs, attention_mask): + assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})" + + # Initialize input / output + input_i = inputs + + # Forward over all encoders + for i in range(self.num_layers): + layer = FlaxBertLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}") + input_i = layer(input_i, attention_mask) + return input_i + + +class FlaxBertEncoder(nn.Module): + num_layers: int + num_heads: int + head_size: int + intermediate_size: int + + @compact + def __call__(self, 
hidden_state, attention_mask): + layer = FlaxBertLayerCollection( + self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer" + )(hidden_state, attention_mask) + return layer + + +class FlaxBertPooler(nn.Module): + @compact + def __call__(self, hidden_state): + cls_token = hidden_state[:, 0] + out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token) + return jax.lax.tanh(out) + + +class FlaxBertModule(nn.Module): + vocab_size: int + hidden_size: int + type_vocab_size: int + max_length: int + num_encoder_layers: int + num_heads: int + head_size: int + intermediate_size: int + + @compact + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): + + # Embedding + embeddings = FlaxBertEmbeddings( + self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings" + )(input_ids, token_type_ids, position_ids, attention_mask) + + # N stacked encoding layers + encoder = FlaxBertEncoder( + self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder" + )(embeddings, attention_mask) + + pooled = FlaxBertPooler(name="pooler")(encoder) + return encoder, pooled + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + BERT_START_DOCSTRING, +) +class FlaxBertModel(FlaxPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well + as a decoder, in which case a layer of cross-attention is added between + the self-attention layers, following the architecture described in `Attention is all you need + `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, + Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 
+ """ + + model_class = FlaxBertModule + config_class = BertConfig + base_model_prefix = "bert" + + @staticmethod + def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict: + jax_state = dict(pt_state) + + # Need to change some parameters name to match Flax names so that we don't have to fork any layer + for key, tensor in pt_state.items(): + # Key parts + key_parts = set(key.split(".")) + + # Every dense layer has "kernel" parameters instead of "weight" + if "dense.weight" in key: + del jax_state[key] + key = key.replace("weight", "kernel") + jax_state[key] = tensor + + # SelfAttention query/key/value: reshape per head, and rename "weight" to "kernel" + if {"query", "key", "value"} & key_parts: + + # Flax SelfAttention decomposes the heads (num_head, size // num_heads) + if "bias" in key: + jax_state[key] = tensor.reshape((config.num_attention_heads, -1)) + elif "weight" in key: + del jax_state[key] + key = key.replace("weight", "kernel") + tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1)) + jax_state[key] = tensor + + # SelfAttention output is not a separate layer, remove one nesting + if "attention.output.dense" in key: + del jax_state[key] + key = key.replace("attention.output.dense", "attention.self.out") + jax_state[key] = tensor + + # SelfAttention output is not a separate layer, remove nesting on layer norm + if "attention.output.LayerNorm" in key: + del jax_state[key] + key = key.replace("attention.output.LayerNorm", "attention.LayerNorm") + jax_state[key] = tensor + + # There are some transposed parameters w.r.t their PyTorch counterpart + if "intermediate.dense.kernel" in key or "output.dense.kernel" in key: + jax_state[key] = tensor.T + + # Self Attention output projection needs to be transposed + if "out.kernel" in key: + jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose( + 1, 2, 0 + ) + + # Pooler needs to transpose its kernel + if "pooler.dense.kernel" in key: +
jax_state[key] = tensor.T + + # Handle LayerNorm conversion + if "LayerNorm" in key: + del jax_state[key] + + # Replace LayerNorm by layer_norm + new_key = key.replace("LayerNorm", "layer_norm") + + if "weight" in key: + new_key = new_key.replace("weight", "gamma") + elif "bias" in key: + new_key = new_key.replace("bias", "beta") + + jax_state[new_key] = tensor + + return jax_state + + def __init__(self, config: BertConfig, state: dict, seed: int = 0, **kwargs): + model = FlaxBertModule( + vocab_size=config.vocab_size, + hidden_size=config.hidden_size, + type_vocab_size=config.type_vocab_size, + max_length=config.max_position_embeddings, + num_encoder_layers=config.num_hidden_layers, + num_heads=config.num_attention_heads, + head_size=config.hidden_size, + intermediate_size=config.intermediate_size, + ) + + super().__init__(config, model, state, seed) + + @property + def module(self) -> nn.Module: + return self._module + + def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None): + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if position_ids is None: + position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1]) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + return self.model.apply( + {"params": self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(token_type_ids, dtype="i4"), + jnp.array(position_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + ) diff --git a/src/transformers/modeling_flax_roberta.py b/src/transformers/modeling_flax_roberta.py new file mode 100644 index 0000000000..b6bf6e2c5d --- /dev/null +++ b/src/transformers/modeling_flax_roberta.py @@ -0,0 +1,450 @@ +# coding=utf-8 +# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict + +import numpy as np + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.linen import compact + +from .configuration_roberta import RobertaConfig +from .file_utils import add_start_docstrings +from .modeling_flax_utils import FlaxPreTrainedModel, gelu +from .utils import logging + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "RobertaConfig" +_TOKENIZER_FOR_DOC = "RobertaTokenizer" + + +ROBERTA_START_DOCSTRING = r""" + + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior. + + Parameters: + config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.RobertaTokenizer`. 
+ See :meth:`transformers.PreTrainedTokenizer.encode` and + :meth:`transformers.PreTrainedTokenizer.__call__` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. 
+ output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +# Copied from transformers.modeling_flax_bert.FlaxBertLayerNorm with Bert->Roberta +class FlaxRobertaLayerNorm(nn.Module): + """Layer normalization (https://arxiv.org/abs/1607.06450). + Operates on the last axis of the input data. + """ + + epsilon: float = 1e-6 + dtype: jnp.dtype = jnp.float32 + bias: bool = True + scale: bool = True + bias_init: jnp.ndarray = nn.initializers.zeros + scale_init: jnp.ndarray = nn.initializers.ones + + @compact + def __call__(self, x): + """Applies layer normalization on the input. + It normalizes the activations of the layer for each given example in a + batch independently, rather than across a batch like Batch Normalization. + i.e. applies a transformation that maintains the mean activation within + each example close to 0 and the activation standard deviation close to 1. + Args: + x: the inputs + epsilon: A small float added to variance to avoid dividing by zero. + dtype: the dtype of the computation (default: float32). + bias: If True, bias (beta) is added. + scale: If True, multiply by scale (gamma). When the next layer is linear + (also e.g. nn.relu), this can be disabled since the scaling will be done + by the next layer. + bias_init: Initializer for bias, by default, zero. + scale_init: Initializer for scale, by default, one. + Returns: + Normalized inputs (the same shape as inputs). 
+ """ + features = x.shape[-1] + mean = jnp.mean(x, axis=-1, keepdims=True) + mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) + var = mean2 - jax.lax.square(mean) + mul = jax.lax.rsqrt(var + self.epsilon) + if self.scale: + mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype) + y = (x - mean) * mul + if self.bias: + y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype) + return y + + +# Copied from transformers.modeling_flax_bert.FlaxBertEmbedding with Bert->Roberta +class FlaxRobertaEmbedding(nn.Module): + """ + Specify a new class for doing the embedding stuff + as Flax's one use 'embedding' for the parameter name + and PyTorch use 'weight' + """ + + vocab_size: int + hidden_size: int + emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1) + + @compact + def __call__(self, inputs): + embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size)) + return jnp.take(embedding, inputs, axis=0) + + +# Copied from transformers.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta +class FlaxRobertaEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + vocab_size: int + hidden_size: int + type_vocab_size: int + max_length: int + + @compact + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): + + # Embed + w_emb = FlaxRobertaEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")( + jnp.atleast_2d(input_ids.astype("i4")) + ) + p_emb = FlaxRobertaEmbedding(self.max_length, self.hidden_size, name="position_embeddings")( + jnp.atleast_2d(position_ids.astype("i4")) + ) + t_emb = FlaxRobertaEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")( + jnp.atleast_2d(token_type_ids.astype("i4")) + ) + + # Sum all embeddings + summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb + + # Layer Norm + layer_norm = 
FlaxRobertaLayerNorm(name="layer_norm")(summed_emb) + + return layer_norm + + +# Copied from transformers.modeling_flax_bert.FlaxBertAttention with Bert->Roberta +class FlaxRobertaAttention(nn.Module): + num_heads: int + head_size: int + + @compact + def __call__(self, hidden_state, attention_mask): + self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")( + hidden_state, attention_mask + ) + + layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(self_att + hidden_state) + return layer_norm + + +# Copied from transformers.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta +class FlaxRobertaIntermediate(nn.Module): + output_size: int + + @compact + def __call__(self, hidden_state): + # TODO: Add ACT2FN reference to change activation function + dense = nn.Dense(features=self.output_size, name="dense")(hidden_state) + return gelu(dense) + + +# Copied from transformers.modeling_flax_bert.FlaxBertOutput with Bert->Roberta +class FlaxRobertaOutput(nn.Module): + @compact + def __call__(self, intermediate_output, attention_output): + hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output) + hidden_state = FlaxRobertaLayerNorm(name="layer_norm")(hidden_state + attention_output) + return hidden_state + + +class FlaxRobertaLayer(nn.Module): + num_heads: int + head_size: int + intermediate_size: int + + @compact + def __call__(self, hidden_state, attention_mask): + attention = FlaxRobertaAttention(self.num_heads, self.head_size, name="attention")( + hidden_state, attention_mask + ) + intermediate = FlaxRobertaIntermediate(self.intermediate_size, name="intermediate")(attention) + output = FlaxRobertaOutput(name="output")(intermediate, attention) + + return output + + +# Copied from transformers.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta +class FlaxRobertaLayerCollection(nn.Module): + """ + Stores N RobertaLayer(s) + """ + + num_layers: int + num_heads: int + head_size: 
int + intermediate_size: int + + @compact + def __call__(self, inputs, attention_mask): + assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})" + + # Initialize input / output + input_i = inputs + + # Forward over all encoders + for i in range(self.num_layers): + layer = FlaxRobertaLayer(self.num_heads, self.head_size, self.intermediate_size, name=f"{i}") + input_i = layer(input_i, attention_mask) + return input_i + + +# Copied from transformers.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta +class FlaxRobertaEncoder(nn.Module): + num_layers: int + num_heads: int + head_size: int + intermediate_size: int + + @compact + def __call__(self, hidden_state, attention_mask): + layer = FlaxRobertaLayerCollection( + self.num_layers, self.num_heads, self.head_size, self.intermediate_size, name="layer" + )(hidden_state, attention_mask) + return layer + + +# Copied from transformers.modeling_flax_bert.FlaxBertPooler with Bert->Roberta +class FlaxRobertaPooler(nn.Module): + @compact + def __call__(self, hidden_state): + cls_token = hidden_state[:, 0] + out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token) + return jax.lax.tanh(out) + + +# Copied from transformers.modeling_flax_bert.FlaxBertModule with Bert->Roberta +class FlaxRobertaModule(nn.Module): + vocab_size: int + hidden_size: int + type_vocab_size: int + max_length: int + num_encoder_layers: int + num_heads: int + head_size: int + intermediate_size: int + + @compact + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask): + + # Embedding + embeddings = FlaxRobertaEmbeddings( + self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings" + )(input_ids, token_type_ids, position_ids, attention_mask) + + # N stacked encoding layers + encoder = FlaxRobertaEncoder( + self.num_encoder_layers, self.num_heads, self.head_size, self.intermediate_size, name="encoder" + )(embeddings, attention_mask) + + pooled = 
@staticmethod
def convert_from_pytorch(pt_state: Dict, config: RobertaConfig) -> Dict:
    """
    Rename and reshape a flat PyTorch RoBERTa state dict so it matches the Flax module's parameter names.

    Args:
        pt_state: Flat mapping from dotted PyTorch parameter names to array-like tensors. Each tensor must
            support ``reshape``, ``transpose`` and ``.T``. The mapping itself is not modified.
        config: Model configuration; only ``num_attention_heads`` and ``hidden_size`` are read.

    Returns:
        A new flat dict with the converted entries stored under their Flax names. Entries that match none of
        the rules below are carried over unchanged.
    """
    jax_state = dict(pt_state)

    # Need to change some parameters name to match Flax names so that we don't have to fork any layer.
    # The loop walks the untouched input dict, so ``jax_state`` can be edited freely while iterating.
    for key, tensor in pt_state.items():
        # Key parts (computed once from the original name; ``key`` itself may be rewritten below)
        key_parts = set(key.split("."))

        # Every dense layer has "kernel" parameters instead of "weight"
        if "dense.weight" in key:
            del jax_state[key]
            key = key.replace("weight", "kernel")
            jax_state[key] = tensor

        # SelfAttention needs also to replace "weight" by "kernel"
        if {"query", "key", "value"} & key_parts:

            # Flax SelfAttention decomposes the heads (num_head, size // num_heads)
            if "bias" in key:
                jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
            # Only real weight matrices are renamed and reshaped. The previous ``elif "weight":`` tested a
            # non-empty string literal and was therefore always true.
            elif "weight" in key:
                del jax_state[key]
                key = key.replace("weight", "kernel")
                tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
                jax_state[key] = tensor

        # SelfAttention output is not a separate layer, remove one nesting
        if "attention.output.dense" in key:
            del jax_state[key]
            key = key.replace("attention.output.dense", "attention.self.out")
            jax_state[key] = tensor

        # SelfAttention output is not a separate layer, remove nesting on layer norm
        if "attention.output.LayerNorm" in key:
            del jax_state[key]
            key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
            jax_state[key] = tensor

        # There are some transposed parameters w.r.t their PyTorch counterpart
        if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
            jax_state[key] = tensor.T

        # Self Attention output projection needs to be transposed
        if "out.kernel" in key:
            jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
                1, 2, 0
            )

        # Pooler needs to transpose its kernel
        if "pooler.dense.kernel" in key:
            jax_state[key] = tensor.T

        # Handle LayerNorm conversion
        if "LayerNorm" in key:
            del jax_state[key]

            # Replace LayerNorm by layer_norm
            new_key = key.replace("LayerNorm", "layer_norm")

            if "weight" in key:
                new_key = new_key.replace("weight", "gamma")
            elif "bias" in key:
                new_key = new_key.replace("bias", "beta")

            jax_state[new_key] = tensor

    return jax_state
@jax.jit
def gelu(x):
    r"""Gaussian error linear unit activation function (exact, erf-based form).

    Computes the element-wise function:

    .. math::
      \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)

    This is the exact formulation, not the tanh approximation. For more information, see
    `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_, section 2.
    """
    return x * 0.5 * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0)))


# Activation lookup keyed by the names used in model configurations.
# "gelu" is the exact erf form defined above and "gelu_new" is the tanh approximation
# (the default behaviour of ``nn.gelu``), which matches the PyTorch ACT2FN table.
ACT2FN = {
    "gelu": gelu,
    "relu": nn.relu,
    "swish": nn.swish,
    "gelu_new": nn.gelu,
}


class FlaxPreTrainedModel(ABC):
    """
    Base class for Flax models: stores the configuration, the Flax module and its parameters, and implements
    loading and saving of weights.

    Subclasses set ``config_class`` and ``model_class`` and implement :meth:`convert_from_pytorch`.
    """

    config_class = None
    pretrained_model_archive_map = {}
    base_model_prefix = ""
    model_class = None

    def __init__(self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0):
        """
        Args:
            config: Model configuration.
            module: Flax module implementing the forward pass.
            params: Parameters handed to the module.
            seed: Seed used to build the model's ``PRNGKey``.

        Raises:
            ValueError: If ``config``, ``module`` or ``params`` is ``None``.
        """
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        if params is None:
            raise ValueError("params cannot be None")

        # Those are private to be exposed as typed property on derived classes.
        self._config = config
        self._module = module

        # Those are public as their type is generic to every derived classes.
        self.key = PRNGKey(seed)
        self.params = params
        self.model = module

    @property
    def config(self) -> PretrainedConfig:
        """The configuration this model was built with."""
        return self._config

    @staticmethod
    @abstractmethod
    def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
        """Convert a flat PyTorch state dict into the flat, Flax-named equivalent."""
        raise NotImplementedError()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiate a pretrained Flax model from a pre-trained model configuration.

        The weights file is resolved from a local file, a remote URL or a model identifier, then read either
        as a Flax archive or, as a fallback, as a PyTorch checkpoint converted with
        :meth:`convert_from_pytorch`.

        Raises:
            EnvironmentError: If the weights cannot be found, downloaded or converted.
        """
        config = kwargs.pop("config", None)
        # state_dict = kwargs.pop("state_dict", None)
        cache_dir = kwargs.pop("cache_dir", None)
        # from_tf = kwargs.pop("from_tf", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        # output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_cdn = kwargs.pop("use_cdn", True)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, use_cdn=use_cdn)

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
            except EnvironmentError:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    msg = f"Couldn't reach server at '{archive_file}' to download pretrained weights."
                else:
                    msg = (
                        f"Model name '{pretrained_model_name_or_path}' "
                        f"was not found in model name list ({', '.join(cls.pretrained_model_archive_map.keys())}). "
                        f"We assumed '{archive_file}' was a path or url to model weight files but "
                        "couldn't find any such file at this path or url."
                    )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
        else:
            # NOTE(review): the ``open`` call below cannot take ``None``, so this branch ends in a TypeError.
            resolved_archive_file = None

        # Instantiate model.
        with open(resolved_archive_file, "rb") as state_f:
            try:
                from flax.serialization import from_bytes

                # NOTE(review): ``from_bytes`` is given the open file object rather than its content;
                # presumably that is what raises the TypeError handled below - confirm before relying
                # on the Flax branch.
                state = from_bytes(cls.model_class, state_f)
            except TypeError:
                try:
                    import torch

                    # NOTE(review): ``torch.load`` unpickles the file; only load checkpoints from trusted sources.
                    state = torch.load(state_f)
                    state = {k: v.numpy() for k, v in state.items()}
                    state = cls.convert_from_pytorch(state, config)
                    # Drop the leading name component and rebuild the nested parameter tree.
                    state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
                except UnpicklingError:
                    raise EnvironmentError(
                        f"Unable to convert model {archive_file} to Flax deserializable object. "
                        "Supported format are PyTorch archive or Flax msgpack"
                    )

        return cls(config, state, *model_args, **model_kwargs)

    def save_pretrained(self, folder):
        """
        Serialize the parameters to ``<folder>/<model_type>.flax``.

        Args:
            folder: Destination directory; it is created, including missing parents, when it does not exist.
        """
        folder_abs = os.path.abspath(folder)

        # ``exist_ok`` covers an already existing folder, and ``makedirs`` also creates missing parents.
        os.makedirs(folder_abs, exist_ok=True)

        # The mode belongs to ``open``. It used to be passed to ``os.path.join``, which turned it into a
        # path component and left the file opened for reading.
        with open(os.path.join(folder_abs, f"{self._config.model_type}.flax"), "wb") as f:
            model_bytes = to_bytes(self.params)
            f.write(model_bytes)
_flax_available: + test_case = unittest.skip("test requires JAX & Flax")(test_case) + return test_case + + def require_sentencepiece(test_case): """ Decorator marking a test that requires SentencePiece. diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index a79dbb8fbf..2ec3614bd2 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -33,6 +33,7 @@ from .file_utils import ( add_end_docstrings, cached_path, hf_bucket_url, + is_flax_available, is_remote_url, is_tf_available, is_tokenizers_available, @@ -47,6 +48,8 @@ if is_tf_available(): if is_torch_available(): import torch +if is_flax_available(): + import jax.numpy as jnp if is_tokenizers_available(): from tokenizers import AddedToken @@ -143,6 +146,7 @@ class TensorType(ExplicitEnum): PYTORCH = "pt" TENSORFLOW = "tf" NUMPY = "np" + JAX = "jax" class CharSpan(NamedTuple): @@ -559,18 +563,27 @@ class BatchEncoding(UserDict): tensor_type = TensorType(tensor_type) # Get a function reference for the correct framework - if tensor_type == TensorType.TENSORFLOW and is_tf_available(): - as_tensor = tf.constant - elif tensor_type == TensorType.PYTORCH and is_torch_available(): - as_tensor = torch.tensor - elif tensor_type == TensorType.NUMPY: - as_tensor = np.asarray - else: - raise ImportError( - "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( - tensor_type + if tensor_type == TensorType.TENSORFLOW: + if not is_tf_available(): + raise ImportError( + "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." 
@require_flax
class FlaxAutoModelTest(unittest.TestCase):
    """Slow integration tests: ``FlaxAutoModel`` dispatch and ``jax.jit`` compatibility of the Flax models."""

    @slow
    def test_bert_from_pretrained(self):
        for model_name in ["bert-base-cased", "bert-large-uncased"]:
            with self.subTest(model_name):
                config = AutoConfig.from_pretrained(model_name)
                self.assertIsNotNone(config)
                self.assertIsInstance(config, BertConfig)

                model = FlaxAutoModel.from_pretrained(model_name)
                self.assertIsNotNone(model)
                self.assertIsInstance(model, FlaxBertModel)

    @slow
    def test_roberta_from_pretrained(self):
        # Imported locally so the file's import block does not need to change.
        from transformers import RobertaConfig

        # RoBERTa checkpoints have no cased/uncased variants; these are the published identifiers.
        for model_name in ["roberta-base", "roberta-large"]:
            with self.subTest(model_name):
                config = AutoConfig.from_pretrained(model_name)
                self.assertIsNotNone(config)
                self.assertIsInstance(config, RobertaConfig)

                model = FlaxAutoModel.from_pretrained(model_name)
                self.assertIsNotNone(model)
                self.assertIsInstance(model, FlaxRobertaModel)

    @slow
    def test_bert_jax_jit(self):
        for model_name in ["bert-base-cased", "bert-large-uncased"]:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = FlaxBertModel.from_pretrained(model_name)
            tokens = tokenizer("Do you support jax jitted function?", return_tensors=TensorType.JAX)

            @jax.jit
            def jitted_model(**kwargs):
                return model(**kwargs)

            # The model returns several arrays, so each one has to be waited on individually.
            for output in jitted_model(**tokens):
                output.block_until_ready()

    @slow
    def test_roberta_jax_jit(self):
        for model_name in ["roberta-base", "roberta-large"]:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = FlaxRobertaModel.from_pretrained(model_name)
            tokens = tokenizer("Do you support jax jitted function?", return_tensors=TensorType.JAX)

            @jax.jit
            def jitted_model(**kwargs):
                return model(**kwargs)

            # The model returns (hidden states, pooled output); wait on each array individually.
            for output in jitted_model(**tokens):
                output.block_until_ready()
def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
    """
    Assert that ``a`` and ``b`` agree element-wise within ``tol``.

    Args:
        a: First array (Flax output).
        b: Second array (PyTorch output converted to NumPy); must broadcast against ``a``.
        tol: Largest absolute element-wise difference that is still accepted.

    Raises:
        AssertionError: If any element differs by more than ``tol``.
    """
    # Compare the largest absolute deviation. A signed sum lets positive and negative errors cancel out,
    # and any negative total would pass regardless of its size.
    diff = abs(a - b).max()
    self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol))
def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
    """
    Assert that ``a`` and ``b`` agree element-wise within ``tol``.

    Args:
        a: First array (Flax output).
        b: Second array (PyTorch output converted to NumPy); must broadcast against ``a``.
        tol: Largest absolute element-wise difference that is still accepted.

    Raises:
        AssertionError: If any element differs by more than ``tol``.
    """
    # Compare the largest absolute deviation. A signed sum lets positive and negative errors cancel out,
    # and any negative total would pass regardless of its size.
    diff = abs(a - b).max()
    self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol))