[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)
* save intermediate * save intermediate * save intermediate * correct flax bert model file * new module / model naming * make style * almost finish BERT * finish roberta * make fix-copies * delete keys file * last refactor * fixes in run_mlm_flax.py * remove pooled from run_mlm_flax.py` * fix gelu | gelu_new * remove Module from inits * splits * dirty print * preventing warmup_steps == 0 * smaller splits * make fix-copies * dirty print * dirty print * initial_evaluation argument * declaration order fix * proper model initialization/loading * proper initialization * run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug * removed tokenizers warning hack, fixed model re-initialization * reverted training_args.py changes * fix flax from pretrained * improve test in flax * apply sylvains tips * update init * make 0.3.0 compatible * revert tevens changes * revert tevens changes 2 * finalize revert * fix bug * add docs * add pretrained to init * Update src/transformers/modeling_flax_utils.py * fix copies * final improvements Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
committed by
GitHub
parent
51adb97cd6
commit
640e6fe190
@@ -13,9 +13,10 @@
|
|||||||
Models
|
Models
|
||||||
-----------------------------------------------------------------------------------------------------------------------
|
-----------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
The base classes :class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` implement the
|
The base classes :class:`~transformers.PreTrainedModel`, :class:`~transformers.TFPreTrainedModel`, and
|
||||||
common methods for loading/saving a model either from a local file or directory, or from a pretrained model
|
:class:`~transformers.FlaxPreTrainedModel` implement the common methods for loading/saving a model either from a local
|
||||||
configuration provided by the library (downloaded from HuggingFace's AWS S3 repository).
|
file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace's AWS
|
||||||
|
S3 repository).
|
||||||
|
|
||||||
:class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` also implement a few methods which
|
:class:`~transformers.PreTrainedModel` and :class:`~transformers.TFPreTrainedModel` also implement a few methods which
|
||||||
are common among all the models to:
|
are common among all the models to:
|
||||||
@@ -57,6 +58,13 @@ TFModelUtilsMixin
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
FlaxPreTrainedModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxPreTrainedModel
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
Generation
|
Generation
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -385,7 +385,7 @@ def training_step(optimizer, batch, dropout_rng):
|
|||||||
# Hide away tokens which doesn't participate in the optimization
|
# Hide away tokens which doesn't participate in the optimization
|
||||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
||||||
|
|
||||||
pooled, logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)
|
logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||||
loss, weight_sum = cross_entropy(logits, targets, token_mask)
|
loss, weight_sum = cross_entropy(logits, targets, token_mask)
|
||||||
return loss / weight_sum
|
return loss / weight_sum
|
||||||
|
|
||||||
@@ -407,7 +407,7 @@ def eval_step(params, batch):
|
|||||||
|
|
||||||
# Hide away tokens which doesn't participate in the optimization
|
# Hide away tokens which doesn't participate in the optimization
|
||||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
||||||
_, logits = model(**batch, params=params, train=False)
|
logits = model(**batch, params=params, train=False)[0]
|
||||||
|
|
||||||
return compute_metrics(logits, targets, token_mask)
|
return compute_metrics(logits, targets, token_mask)
|
||||||
|
|
||||||
@@ -572,8 +572,13 @@ if __name__ == "__main__":
|
|||||||
rng = jax.random.PRNGKey(training_args.seed)
|
rng = jax.random.PRNGKey(training_args.seed)
|
||||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||||
|
|
||||||
model = FlaxBertForMaskedLM.from_pretrained("bert-base-cased", dtype=jnp.float32, dropout_rate=0.1)
|
model = FlaxBertForMaskedLM.from_pretrained(
|
||||||
model.init(jax.random.PRNGKey(training_args.seed), (training_args.train_batch_size, model.config.max_length))
|
"bert-base-cased",
|
||||||
|
dtype=jnp.float32,
|
||||||
|
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
|
||||||
|
seed=training_args.seed,
|
||||||
|
dropout_rate=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup optimizer
|
# Setup optimizer
|
||||||
optimizer = Adam(
|
optimizer = Adam(
|
||||||
|
|||||||
14
setup.py
14
setup.py
@@ -97,7 +97,7 @@ _deps = [
|
|||||||
"fastapi",
|
"fastapi",
|
||||||
"filelock",
|
"filelock",
|
||||||
"flake8>=3.8.3",
|
"flake8>=3.8.3",
|
||||||
"flax==0.2.2",
|
"flax>=0.2.2",
|
||||||
"fugashi>=1.0",
|
"fugashi>=1.0",
|
||||||
"ipadic>=1.0.0,<2.0",
|
"ipadic>=1.0.0,<2.0",
|
||||||
"isort>=5.5.4",
|
"isort>=5.5.4",
|
||||||
@@ -175,7 +175,7 @@ class DepsTableUpdateCommand(Command):
|
|||||||
"deps = {",
|
"deps = {",
|
||||||
entries,
|
entries,
|
||||||
"}",
|
"}",
|
||||||
""
|
"",
|
||||||
]
|
]
|
||||||
target = "src/transformers/dependency_versions_table.py"
|
target = "src/transformers/dependency_versions_table.py"
|
||||||
print(f"updating {target}")
|
print(f"updating {target}")
|
||||||
@@ -232,14 +232,14 @@ extras["dev"] = (
|
|||||||
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
|
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
|
||||||
install_requires = [
|
install_requires = [
|
||||||
deps["dataclasses"] + ";python_version<'3.7'", # dataclasses for Python versions that don't have it
|
deps["dataclasses"] + ";python_version<'3.7'", # dataclasses for Python versions that don't have it
|
||||||
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads
|
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads
|
||||||
deps["numpy"],
|
deps["numpy"],
|
||||||
deps["packaging"], # utilities from PyPA to e.g., compare versions
|
deps["packaging"], # utilities from PyPA to e.g., compare versions
|
||||||
deps["regex"], # for OpenAI GPT
|
deps["regex"], # for OpenAI GPT
|
||||||
deps["requests"], # for downloading models over HTTPS
|
deps["requests"], # for downloading models over HTTPS
|
||||||
deps["sacremoses"], # for XLM
|
deps["sacremoses"], # for XLM
|
||||||
deps["tokenizers"],
|
deps["tokenizers"],
|
||||||
deps["tqdm"], # progress bars in model download and training scripts
|
deps["tqdm"], # progress bars in model download and training scripts
|
||||||
]
|
]
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
|
|||||||
@@ -945,6 +945,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
|
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||||
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel
|
||||||
from .models.bert import FlaxBertForMaskedLM, FlaxBertModel
|
from .models.bert import FlaxBertForMaskedLM, FlaxBertModel
|
||||||
from .models.roberta import FlaxRobertaModel
|
from .models.roberta import FlaxRobertaModel
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ deps = {
|
|||||||
"fastapi": "fastapi",
|
"fastapi": "fastapi",
|
||||||
"filelock": "filelock",
|
"filelock": "filelock",
|
||||||
"flake8": "flake8>=3.8.3",
|
"flake8": "flake8>=3.8.3",
|
||||||
"flax": "flax==0.2.2",
|
"flax": "flax>=0.2.2",
|
||||||
"fugashi": "fugashi>=1.0",
|
"fugashi": "fugashi>=1.0",
|
||||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||||
"isort": "isort>=5.5.4",
|
"isort": "isort>=5.5.4",
|
||||||
@@ -40,8 +40,8 @@ deps = {
|
|||||||
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
|
"sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
|
||||||
"sphinx": "sphinx==3.2.1",
|
"sphinx": "sphinx==3.2.1",
|
||||||
"starlette": "starlette",
|
"starlette": "starlette",
|
||||||
"tensorflow-cpu": "tensorflow-cpu>=2.0",
|
"tensorflow-cpu": "tensorflow-cpu>=2.0,<2.4",
|
||||||
"tensorflow": "tensorflow>=2.0",
|
"tensorflow": "tensorflow>=2.0,<2.4",
|
||||||
"timeout-decorator": "timeout-decorator",
|
"timeout-decorator": "timeout-decorator",
|
||||||
"tokenizers": "tokenizers==0.9.4",
|
"tokenizers": "tokenizers==0.9.4",
|
||||||
"torch": "torch>=1.0",
|
"torch": "torch>=1.0",
|
||||||
|
|||||||
@@ -270,6 +270,7 @@ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
|||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
TF2_WEIGHTS_NAME = "tf_model.h5"
|
TF2_WEIGHTS_NAME = "tf_model.h5"
|
||||||
TF_WEIGHTS_NAME = "model.ckpt"
|
TF_WEIGHTS_NAME = "model.ckpt"
|
||||||
|
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
|
||||||
CONFIG_NAME = "config.json"
|
CONFIG_NAME = "config.json"
|
||||||
MODEL_CARD_NAME = "modelcard.json"
|
MODEL_CARD_NAME = "modelcard.json"
|
||||||
|
|
||||||
|
|||||||
@@ -15,64 +15,65 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from functools import partial
|
||||||
from pickle import UnpicklingError
|
from pickle import UnpicklingError
|
||||||
from typing import Dict
|
from typing import Dict, Set, Tuple, Union
|
||||||
|
|
||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.serialization import to_bytes
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.traverse_util import unflatten_dict
|
from flax.serialization import from_bytes, to_bytes
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
|
from .file_utils import FLAX_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
|
||||||
def gelu(x):
|
|
||||||
r"""
|
|
||||||
Gaussian error linear unit activation function.
|
|
||||||
|
|
||||||
Computes the element-wise function:
|
|
||||||
|
|
||||||
.. math::
|
|
||||||
\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
|
|
||||||
\sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)
|
|
||||||
|
|
||||||
We explicitly use the approximation rather than the exact formulation for speed. For more information, see
|
|
||||||
`Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_, section 2.
|
|
||||||
"""
|
|
||||||
return x * 0.5 * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0)))
|
|
||||||
|
|
||||||
|
|
||||||
ACT2FN = {
|
ACT2FN = {
|
||||||
"gelu": nn.gelu,
|
"gelu": nn.gelu,
|
||||||
"relu": nn.relu,
|
"relu": nn.relu,
|
||||||
"silu": nn.swish,
|
"silu": nn.swish,
|
||||||
"swish": nn.swish,
|
"swish": nn.swish,
|
||||||
"gelu_new": gelu,
|
"gelu_new": partial(nn.gelu, approximate=True),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class FlaxPreTrainedModel(ABC):
|
class FlaxPreTrainedModel(ABC):
|
||||||
|
r"""
|
||||||
|
Base class for all models.
|
||||||
|
|
||||||
|
:class:`~transformers.FlaxPreTrainedModel` takes care of storing the configuration of the models and handles
|
||||||
|
methods for loading, downloading and saving models.
|
||||||
|
|
||||||
|
Class attributes (overridden by derived classes):
|
||||||
|
|
||||||
|
- **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
|
||||||
|
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||||
|
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
|
||||||
|
derived classes of the same architecture adding modules on top of the base model.
|
||||||
|
"""
|
||||||
config_class = None
|
config_class = None
|
||||||
pretrained_model_archive_map = {}
|
|
||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
model_class = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0, dtype: jnp.dtype = jnp.float32
|
self,
|
||||||
|
config: PretrainedConfig,
|
||||||
|
module: nn.Module,
|
||||||
|
input_shape: Tuple = (1, 1),
|
||||||
|
seed: int = 0,
|
||||||
|
dtype: jnp.dtype = jnp.float32,
|
||||||
):
|
):
|
||||||
if config is None:
|
if config is None:
|
||||||
raise ValueError("config cannot be None")
|
raise ValueError("config cannot be None")
|
||||||
|
|
||||||
if params is None:
|
if module is None:
|
||||||
raise ValueError("state cannot be None")
|
raise ValueError("module cannot be None")
|
||||||
|
|
||||||
# Those are private to be exposed as typed property on derived classes.
|
# Those are private to be exposed as typed property on derived classes.
|
||||||
self._config = config
|
self._config = config
|
||||||
@@ -80,9 +81,18 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
|
|
||||||
# Those are public as their type is generic to every derived classes.
|
# Those are public as their type is generic to every derived classes.
|
||||||
self.key = PRNGKey(seed)
|
self.key = PRNGKey(seed)
|
||||||
self.params = params
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
|
# randomely initialized parameters
|
||||||
|
random_params = self.init(self.key, input_shape)
|
||||||
|
|
||||||
|
# save required_params as set
|
||||||
|
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
||||||
|
self.params = random_params
|
||||||
|
|
||||||
|
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
|
||||||
|
raise NotImplementedError(f"init method has to be implemented for {self}")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config(self) -> PretrainedConfig:
|
def config(self) -> PretrainedConfig:
|
||||||
return self._config
|
return self._config
|
||||||
@@ -91,24 +101,130 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
def module(self) -> nn.Module:
|
def module(self) -> nn.Module:
|
||||||
return self._module
|
return self._module
|
||||||
|
|
||||||
|
@property
|
||||||
|
def params(self) -> Union[Dict, FrozenDict]:
|
||||||
|
return self._params
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_params(self) -> Set:
|
||||||
|
return self._required_params
|
||||||
|
|
||||||
|
@params.setter
|
||||||
|
def params(self, params: Union[Dict, FrozenDict]):
|
||||||
|
if isinstance(params, FrozenDict):
|
||||||
|
params = unfreeze(params)
|
||||||
|
param_keys = set(flatten_dict(params).keys())
|
||||||
|
if len(self.required_params - param_keys) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Some parameters are missing. Make sure that `params` include the following "
|
||||||
|
f"parameters {self.required_params - param_keys}"
|
||||||
|
)
|
||||||
|
self._params = freeze(params)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
|
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, dtype: jnp.dtype = jnp.float32, *model_args, **kwargs):
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||||
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
*model_args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
|
||||||
r"""
|
r"""
|
||||||
Instantiate a pretrained Flax model from a pre-trained model configuration.
|
Instantiate a pretrained flax model from a pre-trained model configuration.
|
||||||
|
|
||||||
|
The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
|
||||||
|
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
||||||
|
task.
|
||||||
|
|
||||||
|
The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
|
||||||
|
weights are discarded.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
|
||||||
|
Can be either:
|
||||||
|
|
||||||
|
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||||
|
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||||
|
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||||
|
- A path to a `directory` containing model weights saved using
|
||||||
|
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||||
|
- A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
|
||||||
|
case, ``from_pt`` should be set to :obj:`True`.
|
||||||
|
model_args (sequence of positional arguments, `optional`):
|
||||||
|
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||||
|
config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
|
||||||
|
Can be either:
|
||||||
|
|
||||||
|
- an instance of a class derived from :class:`~transformers.PretrainedConfig`,
|
||||||
|
- a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
|
||||||
|
|
||||||
|
Configuration for the model to use instead of an automatically loaded configuation. Configuration can
|
||||||
|
be automatically loaded when:
|
||||||
|
|
||||||
|
- The model is a model provided by the library (loaded with the `model id` string of a pretrained
|
||||||
|
model).
|
||||||
|
- The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
|
||||||
|
by supplying the save directory.
|
||||||
|
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||||
|
configuration JSON file named `config.json` is found in the directory.
|
||||||
|
cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
|
||||||
|
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||||
|
standard cache should not be used.
|
||||||
|
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Load the model weights from a PyTorch checkpoint save file (see docstring of
|
||||||
|
``pretrained_model_name_or_path`` argument).
|
||||||
|
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
|
cached versions if they exist.
|
||||||
|
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||||
|
file exists.
|
||||||
|
proxies (:obj:`Dict[str, str], `optional`):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||||
|
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||||
|
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
|
||||||
|
identifier allowed by git.
|
||||||
|
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||||
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
|
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
|
||||||
|
automatically loaded:
|
||||||
|
|
||||||
|
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
|
||||||
|
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
|
||||||
|
already been done)
|
||||||
|
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
|
||||||
|
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
|
||||||
|
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
|
||||||
|
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
|
||||||
|
attribute will be passed to the underlying model's ``__init__`` function.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import BertConfig, FlaxBertModel
|
||||||
|
>>> # Download model and configuration from huggingface.co and cache.
|
||||||
|
>>> model = FlaxBertModel.from_pretrained('bert-base-cased')
|
||||||
|
>>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
|
||||||
|
>>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
|
||||||
|
>>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
|
||||||
|
>>> config = BertConfig.from_json_file('./pt_model/config.json')
|
||||||
|
>>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
|
||||||
"""
|
"""
|
||||||
config = kwargs.pop("config", None)
|
config = kwargs.pop("config", None)
|
||||||
# state_dict = kwargs.pop("state_dict", None)
|
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
# from_tf = kwargs.pop("from_tf", False)
|
from_pt = kwargs.pop("from_pt", False)
|
||||||
force_download = kwargs.pop("force_download", False)
|
force_download = kwargs.pop("force_download", False)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
# output_loading_info = kwargs.pop("output_loading_info", False)
|
|
||||||
local_files_only = kwargs.pop("local_files_only", False)
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
|
|
||||||
@@ -135,10 +251,28 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
if pretrained_model_name_or_path is not None:
|
if pretrained_model_name_or_path is not None:
|
||||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
if os.path.isdir(pretrained_model_name_or_path):
|
||||||
|
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||||
|
# Load from a PyTorch checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||||
|
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
|
||||||
|
# Load from a Flax checkpoint
|
||||||
|
archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
|
||||||
|
else:
|
||||||
|
raise EnvironmentError(
|
||||||
|
"Error no file named {} found in directory {} or `from_pt` set to False".format(
|
||||||
|
[FLAX_WEIGHTS_NAME, WEIGHTS_NAME],
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||||
archive_file = pretrained_model_name_or_path
|
archive_file = pretrained_model_name_or_path
|
||||||
else:
|
else:
|
||||||
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision)
|
archive_file = hf_bucket_url(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
@@ -169,31 +303,96 @@ class FlaxPreTrainedModel(ABC):
|
|||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
with open(resolved_archive_file, "rb") as state_f:
|
with open(resolved_archive_file, "rb") as state_f:
|
||||||
try:
|
try:
|
||||||
from flax.serialization import from_bytes
|
if from_pt:
|
||||||
|
|
||||||
state = from_bytes(cls.model_class, state_f)
|
|
||||||
except TypeError:
|
|
||||||
try:
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
state = torch.load(state_f)
|
state = torch.load(state_f)
|
||||||
state = {k: v.numpy() for k, v in state.items()}
|
|
||||||
state = cls.convert_from_pytorch(state, config)
|
|
||||||
state = unflatten_dict({tuple(k.split(".")[1:]): v for k, v in state.items()})
|
|
||||||
except UnpicklingError:
|
|
||||||
raise EnvironmentError(
|
|
||||||
f"Unable to convert model {archive_file} to Flax deserializable object. "
|
|
||||||
"Supported format are PyTorch archive or Flax msgpack"
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(config, state, *model_args, **model_kwargs)
|
state = convert_state_dict_from_pt(cls, state, config)
|
||||||
|
else:
|
||||||
|
state = from_bytes(cls, state_f.read())
|
||||||
|
except UnpicklingError:
|
||||||
|
raise EnvironmentError(
|
||||||
|
f"Unable to convert pytorch model {archive_file} to Flax deserializable object. "
|
||||||
|
)
|
||||||
|
|
||||||
def save_pretrained(self, folder):
|
# init random models
|
||||||
folder_abs = os.path.abspath(folder)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|
||||||
if not os.path.exists(folder_abs):
|
# if model is base model only use model_prefix key
|
||||||
os.mkdir(folder_abs)
|
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
||||||
|
state = state[cls.base_model_prefix]
|
||||||
|
|
||||||
with open(os.path.join(folder_abs, f"{self._config.model_type}.flax", "wb")) as f:
|
# flatten dicts
|
||||||
|
state = flatten_dict(state)
|
||||||
|
random_state = flatten_dict(unfreeze(model.params))
|
||||||
|
|
||||||
|
missing_keys = model.required_params - set(state.keys())
|
||||||
|
unexpected_keys = set(state.keys()) - model.required_params
|
||||||
|
|
||||||
|
# add missing keys as random parameters
|
||||||
|
for missing_key in missing_keys:
|
||||||
|
state[missing_key] = random_state[missing_key]
|
||||||
|
|
||||||
|
if len(unexpected_keys) > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
||||||
|
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
|
||||||
|
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
|
||||||
|
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
|
||||||
|
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
|
||||||
|
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
||||||
|
|
||||||
|
if len(missing_keys) > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
|
||||||
|
f"and are newly initialized: {missing_keys}\n"
|
||||||
|
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
||||||
|
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||||
|
f"you can already use {model.__class__.__name__} for predictions without further training."
|
||||||
|
)
|
||||||
|
|
||||||
|
# set correct parameters
|
||||||
|
model.params = unflatten_dict(state)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||||
|
"""
|
||||||
|
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||||
|
`:func:`~transformers.FlaxPreTrainedModel.from_pretrained`` class method
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
save_directory (:obj:`str` or :obj:`os.PathLike`):
|
||||||
|
Directory to which to save. Will be created if it doesn't exist.
|
||||||
|
"""
|
||||||
|
if os.path.isfile(save_directory):
|
||||||
|
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
|
||||||
|
return
|
||||||
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
|
|
||||||
|
# get abs dir
|
||||||
|
save_directory = os.path.abspath(save_directory)
|
||||||
|
# save config as well
|
||||||
|
self.config.save_pretrained(save_directory)
|
||||||
|
|
||||||
|
# save model
|
||||||
|
with open(os.path.join(save_directory, FLAX_WEIGHTS_NAME), "wb") as f:
|
||||||
model_bytes = to_bytes(self.params)
|
model_bytes = to_bytes(self.params)
|
||||||
f.write(model_bytes)
|
f.write(model_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict_from_pt(model_class: ABC, state: Dict, config: PretrainedConfig):
|
||||||
|
"""
|
||||||
|
Converts a PyTorch parameter state dict to an equivalent Flax parameter state dict
|
||||||
|
"""
|
||||||
|
state = {k: v.numpy() for k, v in state.items()}
|
||||||
|
state = model_class.convert_from_pytorch(state, config)
|
||||||
|
state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
|
||||||
|
return state
|
||||||
|
|||||||
@@ -27,15 +27,6 @@ from .configuration_auto import AutoConfig, BertConfig, RobertaConfig
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
|
||||||
(key, value)
|
|
||||||
for pretrained_map in [
|
|
||||||
FlaxBertModel.pretrained_model_archive_map,
|
|
||||||
FlaxRobertaModel.pretrained_model_archive_map,
|
|
||||||
]
|
|
||||||
for key, value, in pretrained_map.items()
|
|
||||||
)
|
|
||||||
|
|
||||||
FLAX_MODEL_MAPPING = OrderedDict(
|
FLAX_MODEL_MAPPING = OrderedDict(
|
||||||
[
|
[
|
||||||
(RobertaConfig, FlaxRobertaModel),
|
(RobertaConfig, FlaxRobertaModel),
|
||||||
@@ -114,10 +105,9 @@ class FlaxAutoModel(object):
|
|||||||
organization name, like ``dbmdz/bert-base-german-cased``.
|
organization name, like ``dbmdz/bert-base-german-cased``.
|
||||||
- a path to a `directory` containing model weights saved using
|
- a path to a `directory` containing model weights saved using
|
||||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
||||||
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this
|
- a path or url to a `pytorch index checkpoint file` (e.g. `./pt_model/pytorch_model.bin`). In this
|
||||||
case, ``from_tf`` should be set to True and a configuration object should be provided as ``config``
|
case, ``from_pt`` should be set to True and a configuration object should be provided as ``config``
|
||||||
argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model
|
argument.
|
||||||
using the provided conversion scripts and loading the PyTorch model afterwards.
|
|
||||||
|
|
||||||
model_args: (`optional`) Sequence of positional arguments:
|
model_args: (`optional`) Sequence of positional arguments:
|
||||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
|
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
|
||||||
@@ -133,13 +123,6 @@ class FlaxAutoModel(object):
|
|||||||
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
|
||||||
configuration JSON file named `config.json` is found in the directory.
|
configuration JSON file named `config.json` is found in the directory.
|
||||||
|
|
||||||
state_dict: (`optional`) dict:
|
|
||||||
an optional state dictionary for the model to use instead of a state dictionary loaded from saved
|
|
||||||
weights file. This option can be used if you want to create a model from a pretrained configuration but
|
|
||||||
load your own weights. In this case though, you should check if using
|
|
||||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained` and
|
|
||||||
:func:`~transformers.FlaxPreTrainedModel.from_pretrained` is not a simpler option.
|
|
||||||
|
|
||||||
cache_dir: (`optional`) string:
|
cache_dir: (`optional`) string:
|
||||||
Path to a directory in which a downloaded pre-trained model configuration should be cached if the
|
Path to a directory in which a downloaded pre-trained model configuration should be cached if the
|
||||||
standard cache should not be used.
|
standard cache should not be used.
|
||||||
|
|||||||
@@ -20,10 +20,11 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
|
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
|
|
||||||
@@ -205,7 +206,7 @@ class FlaxBertAttention(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
@@ -219,27 +220,28 @@ class FlaxBertAttention(nn.Module):
|
|||||||
bias_init=jax.nn.initializers.zeros,
|
bias_init=jax.nn.initializers.zeros,
|
||||||
name="self",
|
name="self",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(hidden_state, attention_mask)
|
)(hidden_states, attention_mask)
|
||||||
|
|
||||||
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
|
layer_norm = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
|
||||||
return layer_norm
|
return layer_norm
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertIntermediate(nn.Module):
|
class FlaxBertIntermediate(nn.Module):
|
||||||
output_size: int
|
output_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
kernel_init_scale: float = 0.2
|
kernel_init_scale: float = 0.2
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_state):
|
def __call__(self, hidden_states):
|
||||||
# TODO: Add ACT2FN reference to change activation function
|
hidden_states = nn.Dense(
|
||||||
dense = nn.Dense(
|
|
||||||
features=self.output_size,
|
features=self.output_size,
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(hidden_state)
|
)(hidden_states)
|
||||||
return gelu(dense)
|
hidden_states = ACT2FN[self.hidden_act](hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertOutput(nn.Module):
|
class FlaxBertOutput(nn.Module):
|
||||||
@@ -249,27 +251,28 @@ class FlaxBertOutput(nn.Module):
|
|||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
|
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
|
||||||
hidden_state = nn.Dense(
|
hidden_states = nn.Dense(
|
||||||
attention_output.shape[-1],
|
attention_output.shape[-1],
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(intermediate_output)
|
)(intermediate_output)
|
||||||
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
|
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
|
||||||
hidden_state = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
|
hidden_states = FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
|
||||||
return hidden_state
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertLayer(nn.Module):
|
class FlaxBertLayer(nn.Module):
|
||||||
num_heads: int
|
num_heads: int
|
||||||
head_size: int
|
head_size: int
|
||||||
intermediate_size: int
|
intermediate_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
dropout_rate: float = 0.0
|
dropout_rate: float = 0.0
|
||||||
kernel_init_scale: float = 0.2
|
kernel_init_scale: float = 0.2
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
attention = FlaxBertAttention(
|
attention = FlaxBertAttention(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_size,
|
self.head_size,
|
||||||
@@ -277,9 +280,13 @@ class FlaxBertLayer(nn.Module):
|
|||||||
dropout_rate=self.dropout_rate,
|
dropout_rate=self.dropout_rate,
|
||||||
name="attention",
|
name="attention",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
)(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
intermediate = FlaxBertIntermediate(
|
intermediate = FlaxBertIntermediate(
|
||||||
self.intermediate_size, kernel_init_scale=self.kernel_init_scale, name="intermediate", dtype=self.dtype
|
self.intermediate_size,
|
||||||
|
kernel_init_scale=self.kernel_init_scale,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
name="intermediate",
|
||||||
|
dtype=self.dtype,
|
||||||
)(attention)
|
)(attention)
|
||||||
output = FlaxBertOutput(
|
output = FlaxBertOutput(
|
||||||
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
|
kernel_init_scale=self.kernel_init_scale, dropout_rate=self.dropout_rate, name="output", dtype=self.dtype
|
||||||
@@ -297,6 +304,7 @@ class FlaxBertLayerCollection(nn.Module):
|
|||||||
num_heads: int
|
num_heads: int
|
||||||
head_size: int
|
head_size: int
|
||||||
intermediate_size: int
|
intermediate_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
dropout_rate: float = 0.0
|
dropout_rate: float = 0.0
|
||||||
kernel_init_scale: float = 0.2
|
kernel_init_scale: float = 0.2
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
@@ -316,6 +324,7 @@ class FlaxBertLayerCollection(nn.Module):
|
|||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
kernel_init_scale=self.kernel_init_scale,
|
||||||
dropout_rate=self.dropout_rate,
|
dropout_rate=self.dropout_rate,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
name=f"{i}",
|
name=f"{i}",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
@@ -328,22 +337,24 @@ class FlaxBertEncoder(nn.Module):
|
|||||||
num_heads: int
|
num_heads: int
|
||||||
head_size: int
|
head_size: int
|
||||||
intermediate_size: int
|
intermediate_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
dropout_rate: float = 0.0
|
dropout_rate: float = 0.0
|
||||||
kernel_init_scale: float = 0.2
|
kernel_init_scale: float = 0.2
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
layer = FlaxBertLayerCollection(
|
layer = FlaxBertLayerCollection(
|
||||||
self.num_layers,
|
self.num_layers,
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_size,
|
self.head_size,
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
kernel_init_scale=self.kernel_init_scale,
|
||||||
dropout_rate=self.dropout_rate,
|
dropout_rate=self.dropout_rate,
|
||||||
name="layer",
|
name="layer",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
)(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
@@ -352,10 +363,10 @@ class FlaxBertPooler(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_state):
|
def __call__(self, hidden_states):
|
||||||
cls_token = hidden_state[:, 0]
|
cls_token = hidden_states[:, 0]
|
||||||
out = nn.Dense(
|
out = nn.Dense(
|
||||||
hidden_state.shape[-1],
|
hidden_states.shape[-1],
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
@@ -363,62 +374,20 @@ class FlaxBertPooler(nn.Module):
|
|||||||
return nn.tanh(out)
|
return nn.tanh(out)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertModule(nn.Module):
|
|
||||||
vocab_size: int
|
|
||||||
hidden_size: int
|
|
||||||
type_vocab_size: int
|
|
||||||
max_length: int
|
|
||||||
num_encoder_layers: int
|
|
||||||
num_heads: int
|
|
||||||
head_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
||||||
|
|
||||||
@nn.compact
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
|
||||||
|
|
||||||
# Embedding
|
|
||||||
embeddings = FlaxBertEmbeddings(
|
|
||||||
self.vocab_size,
|
|
||||||
self.hidden_size,
|
|
||||||
self.type_vocab_size,
|
|
||||||
self.max_length,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
name="embeddings",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
|
||||||
|
|
||||||
# N stacked encoding layers
|
|
||||||
encoder = FlaxBertEncoder(
|
|
||||||
self.num_encoder_layers,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_size,
|
|
||||||
self.intermediate_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
name="encoder",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(embeddings, attention_mask, deterministic=deterministic)
|
|
||||||
|
|
||||||
pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
|
||||||
return encoder, pooled
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertPredictionHeadTransform(nn.Module):
|
class FlaxBertPredictionHeadTransform(nn.Module):
|
||||||
|
hidden_act: str = "gelu"
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_states):
|
def __call__(self, hidden_states):
|
||||||
hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states)
|
hidden_states = nn.Dense(hidden_states.shape[-1], name="dense", dtype=self.dtype)(hidden_states)
|
||||||
hidden_states = nn.elu(hidden_states) # TODO: ACT2FN[config.hidden_act]
|
hidden_states = ACT2FN[self.hidden_act](hidden_states)
|
||||||
return FlaxBertLayerNorm(name="LayerNorm", dtype=self.dtype)(hidden_states)
|
return FlaxBertLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertLMPredictionHead(nn.Module):
|
class FlaxBertLMPredictionHead(nn.Module):
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
@@ -428,64 +397,57 @@ class FlaxBertLMPredictionHead(nn.Module):
|
|||||||
# Need a link between the two variables so that the bias is correctly
|
# Need a link between the two variables so that the bias is correctly
|
||||||
# resized with `resize_token_embeddings`
|
# resized with `resize_token_embeddings`
|
||||||
|
|
||||||
hidden_states = FlaxBertPredictionHeadTransform(name="transform", dtype=self.dtype)(hidden_states)
|
hidden_states = FlaxBertPredictionHeadTransform(
|
||||||
|
name="transform", hidden_act=self.hidden_act, dtype=self.dtype
|
||||||
|
)(hidden_states)
|
||||||
hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states)
|
hidden_states = nn.Dense(self.vocab_size, name="decoder", dtype=self.dtype)(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertOnlyMLMHead(nn.Module):
|
class FlaxBertOnlyMLMHead(nn.Module):
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
hidden_size: int
|
hidden_act: str = "gelu"
|
||||||
intermediate_size: int
|
|
||||||
head_size: int
|
|
||||||
num_heads: int
|
|
||||||
num_encoder_layers: int
|
|
||||||
type_vocab_size: int
|
|
||||||
max_length: int
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
dtype: jnp.dtype = jnp.float32
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(
|
def __call__(self, hidden_states):
|
||||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
hidden_states = FlaxBertLMPredictionHead(
|
||||||
):
|
vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="predictions", dtype=self.dtype
|
||||||
# Model
|
)(hidden_states)
|
||||||
encoder, pooled = FlaxBertModule(
|
return hidden_states
|
||||||
vocab_size=self.vocab_size,
|
|
||||||
type_vocab_size=self.type_vocab_size,
|
|
||||||
hidden_size=self.hidden_size,
|
|
||||||
intermediate_size=self.intermediate_size,
|
|
||||||
head_size=self.hidden_size,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
num_encoder_layers=self.num_encoder_layers,
|
|
||||||
max_length=self.max_length,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
|
||||||
|
|
||||||
# Compute the prediction scores
|
|
||||||
encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
|
|
||||||
logits = FlaxBertLMPredictionHead(vocab_size=self.vocab_size, name="predictions", dtype=self.dtype)(encoder)
|
|
||||||
|
|
||||||
return logits, pooled
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||||
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
|
||||||
BERT_START_DOCSTRING,
|
|
||||||
)
|
|
||||||
class FlaxBertModel(FlaxPreTrainedModel):
|
|
||||||
"""
|
"""
|
||||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
models.
|
||||||
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
|
||||||
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_class = FlaxBertModule
|
|
||||||
config_class = BertConfig
|
config_class = BertConfig
|
||||||
base_model_prefix = "bert"
|
base_model_prefix = "bert"
|
||||||
|
|
||||||
|
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
return input_ids, attention_mask, token_type_ids, position_ids
|
||||||
|
|
||||||
|
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||||
|
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||||
|
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
||||||
|
)
|
||||||
|
|
||||||
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
|
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
|
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
|
||||||
jax_state = dict(pt_state)
|
jax_state = dict(pt_state)
|
||||||
@@ -501,6 +463,11 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
|||||||
key = key.replace("weight", "kernel")
|
key = key.replace("weight", "kernel")
|
||||||
jax_state[key] = tensor
|
jax_state[key] = tensor
|
||||||
|
|
||||||
|
if "decoder.weight" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
key = key.replace("weight", "kernel")
|
||||||
|
jax_state[key] = tensor.T
|
||||||
|
|
||||||
# SelfAttention needs also to replace "weight" by "kernel"
|
# SelfAttention needs also to replace "weight" by "kernel"
|
||||||
if {"query", "key", "value"} & key_parts:
|
if {"query", "key", "value"} & key_parts:
|
||||||
|
|
||||||
@@ -526,7 +493,7 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
|||||||
jax_state[key] = tensor
|
jax_state[key] = tensor
|
||||||
|
|
||||||
# There are some transposed parameters w.r.t their PyTorch counterpart
|
# There are some transposed parameters w.r.t their PyTorch counterpart
|
||||||
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
|
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key or "transform.dense.kernel" in key:
|
||||||
jax_state[key] = tensor.T
|
jax_state[key] = tensor.T
|
||||||
|
|
||||||
# Self Attention output projection needs to be transposed
|
# Self Attention output projection needs to be transposed
|
||||||
@@ -539,6 +506,11 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
|||||||
if "pooler.dense.kernel" in key:
|
if "pooler.dense.kernel" in key:
|
||||||
jax_state[key] = tensor.T
|
jax_state[key] = tensor.T
|
||||||
|
|
||||||
|
# Hack to correctly load some pytorch models
|
||||||
|
if "predictions.bias" in key:
|
||||||
|
del jax_state[key]
|
||||||
|
jax_state[".".join(key.split(".")[:2]) + ".decoder.bias"] = tensor
|
||||||
|
|
||||||
# Handle LayerNorm conversion
|
# Handle LayerNorm conversion
|
||||||
if "LayerNorm" in key:
|
if "LayerNorm" in key:
|
||||||
del jax_state[key]
|
del jax_state[key]
|
||||||
@@ -555,7 +527,22 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
|||||||
|
|
||||||
return jax_state
|
return jax_state
|
||||||
|
|
||||||
def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
BERT_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class FlaxBertModel(FlaxBertPreTrainedModel):
|
||||||
|
"""
|
||||||
|
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||||
|
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||||
|
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
||||||
|
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||||
|
):
|
||||||
module = FlaxBertModule(
|
module = FlaxBertModule(
|
||||||
vocab_size=config.vocab_size,
|
vocab_size=config.vocab_size,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
@@ -566,10 +553,12 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
|||||||
head_size=config.hidden_size,
|
head_size=config.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
dropout_rate=config.hidden_dropout_prob,
|
dropout_rate=config.hidden_dropout_prob,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(config, module, state, seed)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -601,34 +590,62 @@ class FlaxBertModel(FlaxPreTrainedModel):
|
|||||||
rngs=rngs,
|
rngs=rngs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
|
|
||||||
if token_type_ids is None:
|
|
||||||
token_type_ids = jnp.ones_like(input_ids)
|
|
||||||
|
|
||||||
if position_ids is None:
|
class FlaxBertModule(nn.Module):
|
||||||
position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
|
vocab_size: int
|
||||||
|
hidden_size: int
|
||||||
|
type_vocab_size: int
|
||||||
|
max_length: int
|
||||||
|
num_encoder_layers: int
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
|
dropout_rate: float = 0.0
|
||||||
|
kernel_init_scale: float = 0.2
|
||||||
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
add_pooling_layer: bool = True
|
||||||
|
|
||||||
if attention_mask is None:
|
@nn.compact
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||||
|
|
||||||
return input_ids, attention_mask, token_type_ids, position_ids
|
# Embedding
|
||||||
|
embeddings = FlaxBertEmbeddings(
|
||||||
|
self.vocab_size,
|
||||||
|
self.hidden_size,
|
||||||
|
self.type_vocab_size,
|
||||||
|
self.max_length,
|
||||||
|
kernel_init_scale=self.kernel_init_scale,
|
||||||
|
dropout_rate=self.dropout_rate,
|
||||||
|
name="embeddings",
|
||||||
|
dtype=self.dtype,
|
||||||
|
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
||||||
|
|
||||||
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
|
# N stacked encoding layers
|
||||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
encoder = FlaxBertEncoder(
|
||||||
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
self.num_encoder_layers,
|
||||||
)
|
self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
self.intermediate_size,
|
||||||
|
kernel_init_scale=self.kernel_init_scale,
|
||||||
|
dropout_rate=self.dropout_rate,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
name="encoder",
|
||||||
|
dtype=self.dtype,
|
||||||
|
)(embeddings, attention_mask, deterministic=deterministic)
|
||||||
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
if not self.add_pooling_layer:
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
return encoder
|
||||||
|
|
||||||
self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
pooled = FlaxBertPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
||||||
|
return encoder, pooled
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForMaskedLM(FlaxBertModel):
|
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
|
||||||
def __init__(self, config: BertConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
|
def __init__(
|
||||||
super().__init__(config, state, seed, dtype)
|
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||||
|
):
|
||||||
self._module = FlaxBertOnlyMLMHead(
|
module = FlaxBertForMaskedLMModule(
|
||||||
vocab_size=config.vocab_size,
|
vocab_size=config.vocab_size,
|
||||||
type_vocab_size=config.type_vocab_size,
|
type_vocab_size=config.type_vocab_size,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
@@ -636,10 +653,13 @@ class FlaxBertForMaskedLM(FlaxBertModel):
|
|||||||
head_size=config.hidden_size,
|
head_size=config.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
num_encoder_layers=config.num_hidden_layers,
|
num_encoder_layers=config.num_hidden_layers,
|
||||||
max_length=config.max_length,
|
max_length=config.max_position_embeddings,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -659,7 +679,7 @@ class FlaxBertForMaskedLM(FlaxBertModel):
|
|||||||
if dropout_rng is not None:
|
if dropout_rng is not None:
|
||||||
rngs["dropout"] = dropout_rng
|
rngs["dropout"] = dropout_rng
|
||||||
|
|
||||||
pooled, logits = self.module.apply(
|
return self.module.apply(
|
||||||
{"params": params or self.params},
|
{"params": params or self.params},
|
||||||
jnp.array(input_ids, dtype="i4"),
|
jnp.array(input_ids, dtype="i4"),
|
||||||
jnp.array(attention_mask, dtype="i4"),
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
@@ -669,4 +689,45 @@ class FlaxBertForMaskedLM(FlaxBertModel):
|
|||||||
rngs=rngs,
|
rngs=rngs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return logits, pooled
|
|
||||||
|
class FlaxBertForMaskedLMModule(nn.Module):
|
||||||
|
vocab_size: int
|
||||||
|
hidden_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
head_size: int
|
||||||
|
num_heads: int
|
||||||
|
num_encoder_layers: int
|
||||||
|
type_vocab_size: int
|
||||||
|
max_length: int
|
||||||
|
hidden_act: str
|
||||||
|
dropout_rate: float = 0.0
|
||||||
|
dtype: jnp.dtype = jnp.float32
|
||||||
|
|
||||||
|
@nn.compact
|
||||||
|
def __call__(
|
||||||
|
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||||
|
):
|
||||||
|
# Model
|
||||||
|
encoder = FlaxBertModule(
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
type_vocab_size=self.type_vocab_size,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=self.intermediate_size,
|
||||||
|
head_size=self.hidden_size,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_encoder_layers=self.num_encoder_layers,
|
||||||
|
max_length=self.max_length,
|
||||||
|
dropout_rate=self.dropout_rate,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
dtype=self.dtype,
|
||||||
|
add_pooling_layer=False,
|
||||||
|
name="bert",
|
||||||
|
)(input_ids, attention_mask, token_type_ids, position_ids, deterministic=deterministic)
|
||||||
|
|
||||||
|
# Compute the prediction scores
|
||||||
|
encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
|
||||||
|
logits = FlaxBertOnlyMLMHead(
|
||||||
|
vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
|
||||||
|
)(encoder)
|
||||||
|
|
||||||
|
return (logits,)
|
||||||
|
|||||||
@@ -19,10 +19,11 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
from flax.core.frozen_dict import FrozenDict
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
from ...modeling_flax_utils import FlaxPreTrainedModel, gelu
|
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
|
|
||||||
@@ -33,6 +34,23 @@ _CONFIG_FOR_DOC = "RobertaConfig"
|
|||||||
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
||||||
|
|
||||||
|
|
||||||
|
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||||
|
"""
|
||||||
|
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
||||||
|
are ignored. This is modified from fairseq's `utils.make_positions`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids: jnp.ndarray
|
||||||
|
padding_idx: int
|
||||||
|
|
||||||
|
Returns: jnp.ndarray
|
||||||
|
"""
|
||||||
|
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||||
|
mask = (input_ids != padding_idx).astype("i4")
|
||||||
|
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
|
||||||
|
return incremental_indices.astype("i4") + padding_idx
|
||||||
|
|
||||||
|
|
||||||
ROBERTA_START_DOCSTRING = r"""
|
ROBERTA_START_DOCSTRING = r"""
|
||||||
|
|
||||||
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
|
||||||
@@ -208,7 +226,7 @@ class FlaxRobertaAttention(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||||
@@ -222,28 +240,29 @@ class FlaxRobertaAttention(nn.Module):
|
|||||||
bias_init=jax.nn.initializers.zeros,
|
bias_init=jax.nn.initializers.zeros,
|
||||||
name="self",
|
name="self",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(hidden_state, attention_mask)
|
)(hidden_states, attention_mask)
|
||||||
|
|
||||||
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_state)
|
layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(self_att + hidden_states)
|
||||||
return layer_norm
|
return layer_norm
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
|
||||||
class FlaxRobertaIntermediate(nn.Module):
|
class FlaxRobertaIntermediate(nn.Module):
|
||||||
output_size: int
|
output_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
kernel_init_scale: float = 0.2
|
kernel_init_scale: float = 0.2
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_state):
|
def __call__(self, hidden_states):
|
||||||
# TODO: Add ACT2FN reference to change activation function
|
hidden_states = nn.Dense(
|
||||||
dense = nn.Dense(
|
|
||||||
features=self.output_size,
|
features=self.output_size,
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(hidden_state)
|
)(hidden_states)
|
||||||
return gelu(dense)
|
hidden_states = ACT2FN[self.hidden_act](hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
|
||||||
@@ -254,27 +273,28 @@ class FlaxRobertaOutput(nn.Module):
|
|||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
|
def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
|
||||||
hidden_state = nn.Dense(
|
hidden_states = nn.Dense(
|
||||||
attention_output.shape[-1],
|
attention_output.shape[-1],
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(intermediate_output)
|
)(intermediate_output)
|
||||||
hidden_state = nn.Dropout(rate=self.dropout_rate)(hidden_state, deterministic=deterministic)
|
hidden_states = nn.Dropout(rate=self.dropout_rate)(hidden_states, deterministic=deterministic)
|
||||||
hidden_state = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_state + attention_output)
|
hidden_states = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(hidden_states + attention_output)
|
||||||
return hidden_state
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class FlaxRobertaLayer(nn.Module):
|
class FlaxRobertaLayer(nn.Module):
|
||||||
num_heads: int
|
num_heads: int
|
||||||
head_size: int
|
head_size: int
|
||||||
intermediate_size: int
|
intermediate_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
dropout_rate: float = 0.0
|
dropout_rate: float = 0.0
|
||||||
kernel_init_scale: float = 0.2
|
kernel_init_scale: float = 0.2
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
attention = FlaxRobertaAttention(
|
attention = FlaxRobertaAttention(
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_size,
|
self.head_size,
|
||||||
@@ -282,10 +302,11 @@ class FlaxRobertaLayer(nn.Module):
|
|||||||
dropout_rate=self.dropout_rate,
|
dropout_rate=self.dropout_rate,
|
||||||
name="attention",
|
name="attention",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
)(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
intermediate = FlaxRobertaIntermediate(
|
intermediate = FlaxRobertaIntermediate(
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
kernel_init_scale=self.kernel_init_scale,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
name="intermediate",
|
name="intermediate",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(attention)
|
)(attention)
|
||||||
@@ -306,6 +327,7 @@ class FlaxRobertaLayerCollection(nn.Module):
|
|||||||
num_heads: int
|
num_heads: int
|
||||||
head_size: int
|
head_size: int
|
||||||
intermediate_size: int
|
intermediate_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
dropout_rate: float = 0.0
|
dropout_rate: float = 0.0
|
||||||
kernel_init_scale: float = 0.2
|
kernel_init_scale: float = 0.2
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
@@ -325,6 +347,7 @@ class FlaxRobertaLayerCollection(nn.Module):
|
|||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
kernel_init_scale=self.kernel_init_scale,
|
||||||
dropout_rate=self.dropout_rate,
|
dropout_rate=self.dropout_rate,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
name=f"{i}",
|
name=f"{i}",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)
|
)
|
||||||
@@ -338,22 +361,24 @@ class FlaxRobertaEncoder(nn.Module):
|
|||||||
num_heads: int
|
num_heads: int
|
||||||
head_size: int
|
head_size: int
|
||||||
intermediate_size: int
|
intermediate_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
dropout_rate: float = 0.0
|
dropout_rate: float = 0.0
|
||||||
kernel_init_scale: float = 0.2
|
kernel_init_scale: float = 0.2
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_state, attention_mask, deterministic: bool = True):
|
def __call__(self, hidden_states, attention_mask, deterministic: bool = True):
|
||||||
layer = FlaxRobertaLayerCollection(
|
layer = FlaxRobertaLayerCollection(
|
||||||
self.num_layers,
|
self.num_layers,
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_size,
|
self.head_size,
|
||||||
self.intermediate_size,
|
self.intermediate_size,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
kernel_init_scale=self.kernel_init_scale,
|
||||||
dropout_rate=self.dropout_rate,
|
dropout_rate=self.dropout_rate,
|
||||||
name="layer",
|
name="layer",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
)(hidden_state, attention_mask, deterministic=deterministic)
|
)(hidden_states, attention_mask, deterministic=deterministic)
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
@@ -363,10 +388,10 @@ class FlaxRobertaPooler(nn.Module):
|
|||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
|
||||||
@nn.compact
|
@nn.compact
|
||||||
def __call__(self, hidden_state):
|
def __call__(self, hidden_states):
|
||||||
cls_token = hidden_state[:, 0]
|
cls_token = hidden_states[:, 0]
|
||||||
out = nn.Dense(
|
out = nn.Dense(
|
||||||
hidden_state.shape[-1],
|
hidden_states.shape[-1],
|
||||||
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
|
||||||
name="dense",
|
name="dense",
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
@@ -374,64 +399,12 @@ class FlaxRobertaPooler(nn.Module):
|
|||||||
return nn.tanh(out)
|
return nn.tanh(out)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||||
class FlaxRobertaModule(nn.Module):
|
|
||||||
vocab_size: int
|
|
||||||
hidden_size: int
|
|
||||||
type_vocab_size: int
|
|
||||||
max_length: int
|
|
||||||
num_encoder_layers: int
|
|
||||||
num_heads: int
|
|
||||||
head_size: int
|
|
||||||
intermediate_size: int
|
|
||||||
dropout_rate: float = 0.0
|
|
||||||
kernel_init_scale: float = 0.2
|
|
||||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
|
||||||
|
|
||||||
@nn.compact
|
|
||||||
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
|
||||||
|
|
||||||
# Embedding
|
|
||||||
embeddings = FlaxRobertaEmbeddings(
|
|
||||||
self.vocab_size,
|
|
||||||
self.hidden_size,
|
|
||||||
self.type_vocab_size,
|
|
||||||
self.max_length,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
name="embeddings",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
|
||||||
|
|
||||||
# N stacked encoding layers
|
|
||||||
encoder = FlaxRobertaEncoder(
|
|
||||||
self.num_encoder_layers,
|
|
||||||
self.num_heads,
|
|
||||||
self.head_size,
|
|
||||||
self.intermediate_size,
|
|
||||||
kernel_init_scale=self.kernel_init_scale,
|
|
||||||
dropout_rate=self.dropout_rate,
|
|
||||||
name="encoder",
|
|
||||||
dtype=self.dtype,
|
|
||||||
)(embeddings, attention_mask, deterministic=deterministic)
|
|
||||||
|
|
||||||
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
|
||||||
return encoder, pooled
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
|
||||||
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
|
||||||
ROBERTA_START_DOCSTRING,
|
|
||||||
)
|
|
||||||
class FlaxRobertaModel(FlaxPreTrainedModel):
|
|
||||||
"""
|
"""
|
||||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
models.
|
||||||
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
|
|
||||||
Kaiser and Illia Polosukhin.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_class = FlaxRobertaModule
|
|
||||||
config_class = RobertaConfig
|
config_class = RobertaConfig
|
||||||
base_model_prefix = "roberta"
|
base_model_prefix = "roberta"
|
||||||
|
|
||||||
@@ -504,7 +477,49 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
|||||||
|
|
||||||
return jax_state
|
return jax_state
|
||||||
|
|
||||||
def __init__(self, config: RobertaConfig, state: dict, seed: int = 0, dtype: jnp.dtype = jnp.float32):
|
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||||
|
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||||
|
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
||||||
|
)
|
||||||
|
|
||||||
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
|
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
||||||
|
|
||||||
|
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
|
||||||
|
if token_type_ids is None:
|
||||||
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|
||||||
|
return input_ids, attention_mask, token_type_ids, position_ids
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
|
||||||
|
ROBERTA_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
|
||||||
|
"""
|
||||||
|
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||||
|
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||||
|
all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
|
||||||
|
Kaiser and Illia Polosukhin.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: RobertaConfig,
|
||||||
|
input_shape: Tuple = (1, 1),
|
||||||
|
seed: int = 0,
|
||||||
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
module = FlaxRobertaModule(
|
module = FlaxRobertaModule(
|
||||||
vocab_size=config.vocab_size,
|
vocab_size=config.vocab_size,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
@@ -513,12 +528,14 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
|||||||
num_encoder_layers=config.num_hidden_layers,
|
num_encoder_layers=config.num_hidden_layers,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
head_size=config.hidden_size,
|
head_size=config.hidden_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
dropout_rate=config.hidden_dropout_prob,
|
dropout_rate=config.hidden_dropout_prob,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(config, module, state, seed)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -550,42 +567,53 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
|||||||
rngs=rngs,
|
rngs=rngs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init(self, rng: jax.random.PRNGKey, input_shape: Tuple):
|
|
||||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
|
||||||
jnp.zeros(input_shape, dtype="i4"), None, None, None
|
|
||||||
)
|
|
||||||
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
class FlaxRobertaModule(nn.Module):
|
||||||
|
vocab_size: int
|
||||||
|
hidden_size: int
|
||||||
|
type_vocab_size: int
|
||||||
|
max_length: int
|
||||||
|
num_encoder_layers: int
|
||||||
|
num_heads: int
|
||||||
|
head_size: int
|
||||||
|
intermediate_size: int
|
||||||
|
hidden_act: str = "gelu"
|
||||||
|
dropout_rate: float = 0.0
|
||||||
|
kernel_init_scale: float = 0.2
|
||||||
|
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||||
|
add_pooling_layer: bool = True
|
||||||
|
|
||||||
self.params = self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids)["params"]
|
@nn.compact
|
||||||
|
def __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic: bool = True):
|
||||||
|
|
||||||
def _check_inputs(self, input_ids, attention_mask, token_type_ids, position_ids):
|
# Embedding
|
||||||
|
embeddings = FlaxRobertaEmbeddings(
|
||||||
|
self.vocab_size,
|
||||||
|
self.hidden_size,
|
||||||
|
self.type_vocab_size,
|
||||||
|
self.max_length,
|
||||||
|
kernel_init_scale=self.kernel_init_scale,
|
||||||
|
dropout_rate=self.dropout_rate,
|
||||||
|
name="embeddings",
|
||||||
|
dtype=self.dtype,
|
||||||
|
)(input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic)
|
||||||
|
|
||||||
if token_type_ids is None:
|
# N stacked encoding layers
|
||||||
token_type_ids = jnp.ones_like(input_ids)
|
encoder = FlaxRobertaEncoder(
|
||||||
|
self.num_encoder_layers,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
self.intermediate_size,
|
||||||
|
kernel_init_scale=self.kernel_init_scale,
|
||||||
|
dropout_rate=self.dropout_rate,
|
||||||
|
hidden_act=self.hidden_act,
|
||||||
|
name="encoder",
|
||||||
|
dtype=self.dtype,
|
||||||
|
)(embeddings, attention_mask, deterministic=deterministic)
|
||||||
|
|
||||||
if position_ids is None:
|
if not self.add_pooling_layer:
|
||||||
position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
|
return encoder
|
||||||
|
|
||||||
if attention_mask is None:
|
pooled = FlaxRobertaPooler(kernel_init_scale=self.kernel_init_scale, name="pooler", dtype=self.dtype)(encoder)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
return encoder, pooled
|
||||||
|
|
||||||
return input_ids, attention_mask, token_type_ids, position_ids
|
|
||||||
|
|
||||||
|
|
||||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
|
||||||
"""
|
|
||||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
|
||||||
are ignored. This is modified from fairseq's `utils.make_positions`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids: jnp.ndarray
|
|
||||||
padding_idx: int
|
|
||||||
|
|
||||||
Returns: jnp.ndarray
|
|
||||||
"""
|
|
||||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
|
||||||
mask = (input_ids != padding_idx).astype("i4")
|
|
||||||
incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
|
|
||||||
return incremental_indices.astype("i4") + padding_idx
|
|
||||||
|
|||||||
@@ -2,6 +2,15 @@
|
|||||||
from ..file_utils import requires_flax
|
from ..file_utils import requires_flax
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(self, *args, **kwargs):
|
||||||
|
requires_flax(self)
|
||||||
|
|
||||||
|
|
||||||
FLAX_MODEL_MAPPING = None
|
FLAX_MODEL_MAPPING = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,14 +14,16 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import BertConfig, is_flax_available
|
from transformers import BertConfig, is_flax_available
|
||||||
from transformers.testing_utils import require_flax
|
from transformers.testing_utils import require_flax, slow
|
||||||
|
|
||||||
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from transformers.models.bert.modeling_flax_bert import FlaxBertModel
|
from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertModelTester(unittest.TestCase):
|
class FlaxBertModelTester(unittest.TestCase):
|
||||||
@@ -105,7 +107,14 @@ class FlaxBertModelTester(unittest.TestCase):
|
|||||||
@require_flax
|
@require_flax
|
||||||
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (FlaxBertModel,) if is_flax_available() else ()
|
all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = FlaxBertModelTester(self)
|
self.model_tester = FlaxBertModelTester(self)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
for model_class_name in self.all_model_classes:
|
||||||
|
model = model_class_name.from_pretrained("bert-base-cased")
|
||||||
|
outputs = model(np.ones((1, 1)))
|
||||||
|
self.assertIsNotNone(outputs)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -26,7 +27,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.traverse_util import unflatten_dict
|
from transformers.modeling_flax_utils import convert_state_dict_from_pt
|
||||||
|
|
||||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||||
|
|
||||||
@@ -59,21 +60,13 @@ def random_attention_mask(shape, rng=None):
|
|||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
|
|
||||||
def convert_pt_model_to_flax(pt_model, config, flax_model_cls):
|
|
||||||
state = pt_model.state_dict()
|
|
||||||
state = {k: v.numpy() for k, v in state.items()}
|
|
||||||
state = flax_model_cls.convert_from_pytorch(state, config)
|
|
||||||
state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
|
|
||||||
return flax_model_cls(config, state, dtype=jnp.float32)
|
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
class FlaxModelTesterMixin:
|
class FlaxModelTesterMixin:
|
||||||
model_tester = None
|
model_tester = None
|
||||||
all_model_classes = ()
|
all_model_classes = ()
|
||||||
|
|
||||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||||
diff = np.abs((a - b)).sum()
|
diff = np.abs((a - b)).max()
|
||||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -86,30 +79,54 @@ class FlaxModelTesterMixin:
|
|||||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||||
pt_model = pt_model_class(config).eval()
|
pt_model = pt_model_class(config).eval()
|
||||||
|
|
||||||
fx_model = convert_pt_model_to_flax(pt_model, config, model_class)
|
fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
|
||||||
|
fx_model = model_class(config, dtype=jnp.float32)
|
||||||
|
fx_model.params = fx_state
|
||||||
|
|
||||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
|
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||||
|
|
||||||
fx_outputs = fx_model(**inputs_dict)
|
fx_outputs = fx_model(**inputs_dict)
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
|
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_model.save_pretrained(tmpdirname)
|
||||||
|
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||||
|
self.assertEqual(
|
||||||
|
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||||
|
)
|
||||||
|
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||||
|
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
||||||
|
|
||||||
|
def test_from_pretrained_save_pretrained(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
with self.subTest(model_class.__name__):
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
outputs = model(**inputs_dict)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
outputs_loaded = model_loaded(**inputs_dict)
|
||||||
|
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||||
|
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_jit_compilation(self):
|
def test_jit_compilation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
with self.subTest(model_class.__name__):
|
with self.subTest(model_class.__name__):
|
||||||
|
model = model_class(config)
|
||||||
# TODO later: have some way to initialize easily a Flax model from config, for now I go through PT
|
|
||||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
|
||||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
|
||||||
pt_model = pt_model_class(config).eval()
|
|
||||||
|
|
||||||
model = convert_pt_model_to_flax(pt_model, config, model_class)
|
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
|
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
|
||||||
@@ -125,3 +142,14 @@ class FlaxModelTesterMixin:
|
|||||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||||
self.assertEqual(jitted_output.shape, output.shape)
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
|
def test_naming_convention(self):
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model_class_name = model_class.__name__
|
||||||
|
module_class_name = (
|
||||||
|
model_class_name[:-5] + "Module" if model_class_name[-5:] == "Model" else model_class_name + "Module"
|
||||||
|
)
|
||||||
|
bert_modeling_flax_module = __import__(model_class.__module__, fromlist=[module_class_name])
|
||||||
|
module_cls = getattr(bert_modeling_flax_module, module_class_name)
|
||||||
|
|
||||||
|
self.assertIsNotNone(module_cls)
|
||||||
|
|||||||
@@ -14,8 +14,10 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from transformers import RobertaConfig, is_flax_available
|
from transformers import RobertaConfig, is_flax_available
|
||||||
from transformers.testing_utils import require_flax
|
from transformers.testing_utils import require_flax, slow
|
||||||
|
|
||||||
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
@@ -109,3 +111,10 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = FlaxRobertaModelTester(self)
|
self.model_tester = FlaxRobertaModelTester(self)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
for model_class_name in self.all_model_classes:
|
||||||
|
model = model_class_name.from_pretrained("roberta-base")
|
||||||
|
outputs = model(np.ones((1, 1)))
|
||||||
|
self.assertIsNotNone(outputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user