[Flax] improve large model init and loading (#16148)
* begin do_init * add params_shape_tree * raise error if params are accessed when do_init is False * don't allow do_init=False when keys are missing * make shape tree a property * assign self._params at the end * add test for do_init * add do_init arg to all flax models * fix param setting * disbale do_init for composite models * update test * add do_init in FlaxBigBirdForMultipleChoice * better names and errors * improve test * style * add a warning when do_init=False * remove extra if * set params after _required_params * add test for from_pretrained * do_init => _do_init * chage warning to info * fix typo * add params in init_weights * add params to gpt neo init * add params to init_weights * update do_init test * Trigger CI * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * update template * trigger CI * style * style * fix template Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -140,7 +140,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
|||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensor
|
# init input tensor
|
||||||
input_ids = jnp.zeros(input_shape[0], dtype="i4")
|
input_ids = jnp.zeros(input_shape[0], dtype="i4")
|
||||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
base_model_prefix = ""
|
base_model_prefix = ""
|
||||||
main_input_name = "input_ids"
|
main_input_name = "input_ids"
|
||||||
_auto_class = None
|
_auto_class = None
|
||||||
|
_missing_keys = set()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -98,6 +99,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
input_shape: Tuple = (1, 1),
|
input_shape: Tuple = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
):
|
):
|
||||||
if config is None:
|
if config is None:
|
||||||
raise ValueError("config cannot be None")
|
raise ValueError("config cannot be None")
|
||||||
@@ -112,15 +114,35 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
# Those are public as their type is generic to every derived classes.
|
# Those are public as their type is generic to every derived classes.
|
||||||
self.key = PRNGKey(seed)
|
self.key = PRNGKey(seed)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
self.input_shape = input_shape
|
||||||
|
|
||||||
# randomly initialized parameters
|
# To check if the model was intialized automatically.
|
||||||
random_params = self.init_weights(self.key, input_shape)
|
self._is_initialized = _do_init
|
||||||
|
|
||||||
|
if _do_init:
|
||||||
|
# randomly initialized parameters
|
||||||
|
random_params = self.init_weights(self.key, input_shape)
|
||||||
|
params_shape_tree = jax.eval_shape(lambda params: params, random_params)
|
||||||
|
else:
|
||||||
|
init_fn = partial(self.init_weights, input_shape=input_shape)
|
||||||
|
params_shape_tree = jax.eval_shape(init_fn, self.key)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Model weights are not initialized as `_do_init` is set to `False`. "
|
||||||
|
f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the shape of the parameters
|
||||||
|
self._params_shape_tree = params_shape_tree
|
||||||
|
|
||||||
# save required_params as set
|
# save required_params as set
|
||||||
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
|
self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
|
||||||
self.params = random_params
|
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
|
# initialize the parameters
|
||||||
|
if _do_init:
|
||||||
|
self.params = random_params
|
||||||
|
|
||||||
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
|
||||||
raise NotImplementedError(f"init method has to be implemented for {self}")
|
raise NotImplementedError(f"init method has to be implemented for {self}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -147,14 +169,31 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def params(self) -> Union[Dict, FrozenDict]:
|
def params(self) -> Union[Dict, FrozenDict]:
|
||||||
|
if not self._is_initialized:
|
||||||
|
raise ValueError(
|
||||||
|
"`params` cannot be accessed from model when the model is created with `_do_init=False`. "
|
||||||
|
"You must call `init_weights` manually and store the params outside of the model and "
|
||||||
|
"pass it explicitly where needed."
|
||||||
|
)
|
||||||
return self._params
|
return self._params
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_params(self) -> Set:
|
def required_params(self) -> Set:
|
||||||
return self._required_params
|
return self._required_params
|
||||||
|
|
||||||
|
@property
|
||||||
|
def params_shape_tree(self) -> Dict:
|
||||||
|
return self._params_shape_tree
|
||||||
|
|
||||||
@params.setter
|
@params.setter
|
||||||
def params(self, params: Union[Dict, FrozenDict]):
|
def params(self, params: Union[Dict, FrozenDict]):
|
||||||
|
# don't set params if the model is not initialized
|
||||||
|
if not self._is_initialized:
|
||||||
|
raise ValueError(
|
||||||
|
"`params` cannot be set from model when the model is created with `_do_init=False`. "
|
||||||
|
"You store the params outside of the model."
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(params, FrozenDict):
|
if isinstance(params, FrozenDict):
|
||||||
params = unfreeze(params)
|
params = unfreeze(params)
|
||||||
param_keys = set(flatten_dict(params).keys())
|
param_keys = set(flatten_dict(params).keys())
|
||||||
@@ -417,6 +456,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
revision = kwargs.pop("revision", None)
|
revision = kwargs.pop("revision", None)
|
||||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||||
from_auto_class = kwargs.pop("_from_auto", False)
|
from_auto_class = kwargs.pop("_from_auto", False)
|
||||||
|
_do_init = kwargs.pop("_do_init", True)
|
||||||
|
|
||||||
user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
|
user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
|
||||||
if from_pipeline is not None:
|
if from_pipeline is not None:
|
||||||
@@ -553,7 +593,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
resolved_archive_file = None
|
resolved_archive_file = None
|
||||||
|
|
||||||
# init random models
|
# init random models
|
||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
|
||||||
|
|
||||||
if from_pt:
|
if from_pt:
|
||||||
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
|
state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file)
|
||||||
@@ -577,25 +617,36 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
# make sure all arrays are stored as jnp.arrays
|
# make sure all arrays are stored as jnp.arrays
|
||||||
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
|
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
|
||||||
# https://github.com/google/flax/issues/1261
|
# https://github.com/google/flax/issues/1261
|
||||||
state = jax.tree_util.tree_map(jnp.array, state)
|
if _do_init:
|
||||||
|
state = jax.tree_util.tree_map(jnp.array, state)
|
||||||
|
else:
|
||||||
|
# keep the params on CPU if we don't want to initialize
|
||||||
|
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
|
||||||
|
|
||||||
# if model is base model only use model_prefix key
|
# if model is base model only use model_prefix key
|
||||||
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
|
||||||
state = state[cls.base_model_prefix]
|
state = state[cls.base_model_prefix]
|
||||||
|
|
||||||
# if model is head model and we are loading weights from base model
|
# if model is head model and we are loading weights from base model
|
||||||
# we initialize new params dict with base_model_prefix
|
# we initialize new params dict with base_model_prefix
|
||||||
if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
|
if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
|
||||||
state = {cls.base_model_prefix: state}
|
state = {cls.base_model_prefix: state}
|
||||||
|
|
||||||
# flatten dicts
|
# flatten dicts
|
||||||
state = flatten_dict(state)
|
state = flatten_dict(state)
|
||||||
|
|
||||||
random_state = flatten_dict(unfreeze(model.params))
|
random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))
|
||||||
|
|
||||||
missing_keys = model.required_params - set(state.keys())
|
missing_keys = model.required_params - set(state.keys())
|
||||||
unexpected_keys = set(state.keys()) - model.required_params
|
unexpected_keys = set(state.keys()) - model.required_params
|
||||||
|
|
||||||
|
if missing_keys and not _do_init:
|
||||||
|
logger.warn(
|
||||||
|
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
|
||||||
|
f"Make sure to call model.init_weights to initialize the missing weights."
|
||||||
|
)
|
||||||
|
cls._missing_keys = missing_keys
|
||||||
|
|
||||||
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||||
# matching the weights in the model.
|
# matching the weights in the model.
|
||||||
mismatched_keys = []
|
mismatched_keys = []
|
||||||
@@ -612,9 +663,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
"model."
|
"model."
|
||||||
)
|
)
|
||||||
|
|
||||||
# add missing keys as random parameters
|
# add missing keys as random parameters if we are initializing
|
||||||
for missing_key in missing_keys:
|
if missing_keys and _do_init:
|
||||||
state[missing_key] = random_state[missing_key]
|
for missing_key in missing_keys:
|
||||||
|
state[missing_key] = random_state[missing_key]
|
||||||
|
|
||||||
# remove unexpected keys to not be saved again
|
# remove unexpected keys to not be saved again
|
||||||
for unexpected_key in unexpected_keys:
|
for unexpected_key in unexpected_keys:
|
||||||
@@ -680,10 +732,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
|
"See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
|
||||||
)
|
)
|
||||||
|
|
||||||
# set correct parameters
|
if _do_init:
|
||||||
model.params = unflatten_dict(state)
|
# set correct parameters
|
||||||
|
model.params = unflatten_dict(state)
|
||||||
return model
|
return model
|
||||||
|
else:
|
||||||
|
return model, unflatten_dict(state)
|
||||||
|
|
||||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
|
def save_pretrained(self, save_directory: Union[str, os.PathLike], params=None, push_to_hub=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
@@ -522,12 +523,13 @@ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple = (1, 1),
|
input_shape: Tuple = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
token_type_ids = jnp.zeros_like(input_ids)
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
@@ -537,9 +539,19 @@ class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False)[
|
random_params = self.module.init(
|
||||||
"params"
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
|
||||||
]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -912,12 +913,13 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple[int] = (1, 1),
|
input_shape: Tuple[int] = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for FlaxBartForSequenceClassificationModule
|
# make sure initialization pass will work for FlaxBartForSequenceClassificationModule
|
||||||
@@ -933,7 +935,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -943,6 +945,16 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1737,14 +1749,15 @@ class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple[int] = (1, 1),
|
input_shape: Tuple[int] = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
config.is_encoder_decoder = False
|
config.is_encoder_decoder = False
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|||||||
@@ -22,8 +22,9 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
|
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
FlaxBaseModelOutput,
|
FlaxBaseModelOutput,
|
||||||
@@ -591,13 +592,21 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
module_class: nn.Module = None
|
module_class: nn.Module = None
|
||||||
|
|
||||||
def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: BeitConfig,
|
||||||
|
input_shape=None,
|
||||||
|
seed: int = 0,
|
||||||
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
input_shape = (1, config.image_size, config.image_size, 3)
|
input_shape = (1, config.image_size, config.image_size, 3)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
|
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
|
||||||
|
|
||||||
@@ -605,7 +614,17 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
dropout_rng, droppath_rng = jax.random.split(dropout_rng)
|
dropout_rng, droppath_rng = jax.random.split(dropout_rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, pixel_values, return_dict=False)["params"]
|
random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
@@ -616,12 +617,18 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
module_class: nn.Module = None
|
module_class: nn.Module = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
self,
|
||||||
|
config: BertConfig,
|
||||||
|
input_shape: Tuple = (1, 1),
|
||||||
|
seed: int = 0,
|
||||||
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
token_type_ids = jnp.zeros_like(input_ids)
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
@@ -632,10 +639,20 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
@@ -1420,6 +1421,7 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Optional[tuple] = None,
|
input_shape: Optional[tuple] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
@@ -1428,9 +1430,9 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
elif input_shape is None:
|
elif input_shape is None:
|
||||||
input_shape = (1, 1)
|
input_shape = (1, 1)
|
||||||
|
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
token_type_ids = jnp.zeros_like(input_ids)
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
@@ -1441,10 +1443,20 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -1897,13 +1909,14 @@ class FlaxBigBirdForMultipleChoice(FlaxBigBirdPreTrainedModel):
|
|||||||
input_shape: Optional[tuple] = None,
|
input_shape: Optional[tuple] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if config.attention_type == "block_sparse" and input_shape is None:
|
if config.attention_type == "block_sparse" and input_shape is None:
|
||||||
input_shape = (1, 1, 12 * config.block_size)
|
input_shape = (1, 1, 12 * config.block_size)
|
||||||
elif input_shape is None:
|
elif input_shape is None:
|
||||||
input_shape = (1, 1)
|
input_shape = (1, 1)
|
||||||
super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
|
|
||||||
overwrite_call_docstring(
|
overwrite_call_docstring(
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -887,12 +888,13 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple[int] = (1, 1),
|
input_shape: Tuple[int] = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule
|
# make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule
|
||||||
@@ -908,7 +910,7 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -918,6 +920,16 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -25,9 +25,10 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -885,12 +886,13 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple[int] = (1, 1),
|
input_shape: Tuple[int] = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule
|
# make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule
|
||||||
@@ -906,7 +908,7 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -916,6 +918,16 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -19,9 +19,10 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
|
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
|
||||||
@@ -585,12 +586,18 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
module_class: nn.Module = None
|
module_class: nn.Module = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: CLIPTextConfig, input_shape=(1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
self,
|
||||||
|
config: CLIPTextConfig,
|
||||||
|
input_shape=(1, 1),
|
||||||
|
seed: int = 0,
|
||||||
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensor
|
# init input tensor
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
|
||||||
@@ -599,7 +606,17 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]
|
random_params = self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -654,21 +671,32 @@ class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Optional[Tuple] = None,
|
input_shape: Optional[Tuple] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
input_shape = (1, config.image_size, config.image_size, 3)
|
input_shape = (1, config.image_size, config.image_size, 3)
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensor
|
# init input tensor
|
||||||
pixel_values = jax.random.normal(rng, input_shape)
|
pixel_values = jax.random.normal(rng, input_shape)
|
||||||
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, pixel_values)["params"]
|
random_params = self.module.init(rngs, pixel_values)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -714,14 +742,15 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Optional[Tuple] = None,
|
input_shape: Optional[Tuple] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensor
|
# init input tensor
|
||||||
input_ids = jnp.zeros(input_shape[0], dtype="i4")
|
input_ids = jnp.zeros(input_shape[0], dtype="i4")
|
||||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
|
||||||
@@ -732,7 +761,17 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
|
random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
@@ -428,12 +429,13 @@ class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple = (1, 1),
|
input_shape: Tuple = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
@@ -441,7 +443,17 @@ class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
|
random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -541,12 +542,13 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple = (1, 1),
|
input_shape: Tuple = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
token_type_ids = jnp.zeros_like(input_ids)
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
@@ -557,10 +559,20 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -315,11 +316,17 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Optional[Tuple] = None,
|
input_shape: Optional[Tuple] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
input_shape = ((1, 1), (1, 1))
|
input_shape = ((1, 1), (1, 1))
|
||||||
|
|
||||||
|
if not _do_init:
|
||||||
|
raise ValueError(
|
||||||
|
"`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
|
||||||
|
)
|
||||||
|
|
||||||
if config.decoder.cross_attention_hidden_size is not None:
|
if config.decoder.cross_attention_hidden_size is not None:
|
||||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -330,9 +337,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
encoder_input_shape, decoder_input_shape = input_shape
|
encoder_input_shape, decoder_input_shape = input_shape
|
||||||
|
|
||||||
# init input tensors
|
# init input tensors
|
||||||
@@ -356,7 +363,7 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -366,6 +373,16 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -18,9 +18,10 @@ from typing import Any, Optional, Tuple
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
@@ -394,12 +395,13 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple = (1, 1),
|
input_shape: Tuple = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
@@ -422,7 +424,17 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
||||||
|
|
||||||
return module_init_outputs["params"]
|
random_params = module_init_outputs["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length):
|
def init_cache(self, batch_size, max_length):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -19,9 +19,10 @@ from typing import Optional, Tuple
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
||||||
@@ -353,12 +354,13 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple = (1, 1),
|
input_shape: Tuple = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
@@ -366,7 +368,17 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
|
random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length):
|
def init_cache(self, batch_size, max_length):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -21,9 +21,10 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
||||||
@@ -373,12 +374,13 @@ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple = (1, 1),
|
input_shape: Tuple = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
@@ -401,7 +403,17 @@ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
||||||
|
|
||||||
return module_init_outputs["params"]
|
random_params = module_init_outputs["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length):
|
def init_cache(self, batch_size, max_length):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -882,12 +883,13 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple[int] = (1, 1),
|
input_shape: Tuple[int] = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for FlaxMarianForSequenceClassificationModule
|
# make sure initialization pass will work for FlaxMarianForSequenceClassificationModule
|
||||||
@@ -903,7 +905,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -913,6 +915,16 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -951,12 +952,13 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple[int] = (1, 1),
|
input_shape: Tuple[int] = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for FlaxMBartForSequenceClassificationModule
|
# make sure initialization pass will work for FlaxMBartForSequenceClassificationModule
|
||||||
@@ -972,7 +974,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -982,6 +984,16 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->MBart
|
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->MBart
|
||||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -25,9 +25,10 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -901,12 +902,13 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple[int] = (1, 1),
|
input_shape: Tuple[int] = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
@@ -920,7 +922,7 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -930,6 +932,16 @@ class FlaxPegasusPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -19,8 +19,9 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -585,12 +586,13 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple = (1, 1),
|
input_shape: Tuple = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
token_type_ids = jnp.ones_like(input_ids)
|
token_type_ids = jnp.ones_like(input_ids)
|
||||||
@@ -601,10 +603,20 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
@@ -621,12 +622,13 @@ class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple = (1, 1),
|
input_shape: Tuple = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
token_type_ids = jnp.zeros_like(input_ids)
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
@@ -636,9 +638,19 @@ class FlaxRoFormerPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False)[
|
random_params = self.module.init(
|
||||||
"params"
|
rngs, input_ids, attention_mask, token_type_ids, head_mask, return_dict=False
|
||||||
]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ from typing import Optional, Tuple, Union
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -343,8 +344,15 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Optional[Tuple] = None,
|
input_shape: Optional[Tuple] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
|
||||||
|
if not _do_init:
|
||||||
|
raise ValueError(
|
||||||
|
"`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
|
||||||
|
)
|
||||||
|
|
||||||
if config.decoder.cross_attention_hidden_size is not None:
|
if config.decoder.cross_attention_hidden_size is not None:
|
||||||
# Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
|
# Raise ValueError or option to project enc to dec hidden_size (eg EncAdapterLayer)
|
||||||
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
|
||||||
@@ -365,9 +373,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
|
decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
|
||||||
input_shape = ((1, encoder_input_length), (1, decoder_input_length))
|
input_shape = ((1, encoder_input_length), (1, decoder_input_length))
|
||||||
|
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
encoder_input_shape, decoder_input_shape = input_shape
|
encoder_input_shape, decoder_input_shape = input_shape
|
||||||
|
|
||||||
# init input DeviceArrays
|
# init input DeviceArrays
|
||||||
@@ -390,7 +398,7 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
inputs,
|
inputs,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -399,6 +407,16 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -23,9 +23,10 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
from ...modeling_flax_outputs import (
|
from ...modeling_flax_outputs import (
|
||||||
@@ -919,12 +920,13 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple[int] = (1, 1),
|
input_shape: Tuple[int] = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
|
|
||||||
@@ -935,7 +937,7 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -943,6 +945,16 @@ class FlaxT5PreTrainedModel(FlaxPreTrainedModel):
|
|||||||
decoder_attention_mask,
|
decoder_attention_mask,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -282,8 +283,14 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Optional[Tuple] = None,
|
input_shape: Optional[Tuple] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
if not _do_init:
|
||||||
|
raise ValueError(
|
||||||
|
"`FlaxVisionEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
|
||||||
|
)
|
||||||
|
|
||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
num_channels = getattr(config.encoder, "num_channels", 3)
|
num_channels = getattr(config.encoder, "num_channels", 3)
|
||||||
input_shape = (
|
input_shape = (
|
||||||
@@ -301,9 +308,9 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
encoder_input_shape, decoder_input_shape = input_shape
|
encoder_input_shape, decoder_input_shape = input_shape
|
||||||
|
|
||||||
# init input tensors
|
# init input tensors
|
||||||
@@ -325,7 +332,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
pixel_values,
|
pixel_values,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
@@ -333,6 +340,16 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ from typing import Optional, Tuple
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
|
|
||||||
from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring
|
from ...modeling_flax_utils import FlaxPreTrainedModel, append_replace_return_docstrings, overwrite_call_docstring
|
||||||
from ...utils import add_start_docstrings, logging
|
from ...utils import add_start_docstrings, logging
|
||||||
@@ -225,15 +226,22 @@ class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Optional[Tuple] = None,
|
input_shape: Optional[Tuple] = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
|
||||||
|
if not _do_init:
|
||||||
|
raise ValueError(
|
||||||
|
"`FlaxVisionTextDualEncoderModel` cannot be created without initializing, `_do_init` must be `True`."
|
||||||
|
)
|
||||||
|
|
||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
||||||
|
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensor
|
# init input tensor
|
||||||
input_ids = jnp.zeros(input_shape[0], dtype="i4")
|
input_ids = jnp.zeros(input_shape[0], dtype="i4")
|
||||||
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
|
||||||
@@ -245,7 +253,19 @@ class FlaxVisionTextDualEncoderModel(FlaxPreTrainedModel):
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
|
random_params = self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)[
|
||||||
|
"params"
|
||||||
|
]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -18,8 +18,9 @@ from typing import Optional, Tuple
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
|
|
||||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
|
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
|
||||||
from ...modeling_flax_utils import (
|
from ...modeling_flax_utils import (
|
||||||
@@ -407,20 +408,38 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
main_input_name = "pixel_values"
|
main_input_name = "pixel_values"
|
||||||
module_class: nn.Module = None
|
module_class: nn.Module = None
|
||||||
|
|
||||||
def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: ViTConfig,
|
||||||
|
input_shape=None,
|
||||||
|
seed: int = 0,
|
||||||
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
if input_shape is None:
|
if input_shape is None:
|
||||||
input_shape = (1, config.image_size, config.image_size, 3)
|
input_shape = (1, config.image_size, config.image_size, 3)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
|
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
|
||||||
|
|
||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, pixel_values, return_dict=False)["params"]
|
random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
@@ -23,8 +23,9 @@ import flax
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
||||||
@@ -858,19 +859,30 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple = (1, 1024),
|
input_shape: Tuple = (1, 1024),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_values = jnp.zeros(input_shape, dtype="i4")
|
input_values = jnp.zeros(input_shape, dtype="i4")
|
||||||
attention_mask = jnp.ones_like(input_values)
|
attention_mask = jnp.ones_like(input_values)
|
||||||
params_rng, dropout_rng = jax.random.split(rng, 2)
|
params_rng, dropout_rng = jax.random.split(rng, 2)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
|
random_params = self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
@@ -25,9 +25,10 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -561,12 +562,13 @@ class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
input_shape: Tuple[int] = (1, 1),
|
input_shape: Tuple[int] = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
@@ -589,7 +591,17 @@ class FlaxXGLMPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
|
||||||
|
|
||||||
return module_init_outputs["params"]
|
random_params = module_init_outputs["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length):
|
def init_cache(self, batch_size, max_length):
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ import numpy as np
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
@@ -586,12 +587,18 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
|||||||
module_class: nn.Module = None
|
module_class: nn.Module = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: {{cookiecutter.camelcase_modelname}}Config, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
self,
|
||||||
|
config: {{cookiecutter.camelcase_modelname}}Config,
|
||||||
|
input_shape: Tuple = (1, 1),
|
||||||
|
seed: int = 0,
|
||||||
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
token_type_ids = jnp.zeros_like(input_ids)
|
token_type_ids = jnp.zeros_like(input_ids)
|
||||||
@@ -602,10 +609,20 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -1130,9 +1147,10 @@ from typing import Callable, Optional, Tuple
|
|||||||
import flax.linen as nn
|
import flax.linen as nn
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
|
||||||
from flax.linen import combine_masks, make_causal_mask
|
from flax.linen import combine_masks, make_causal_mask
|
||||||
from flax.linen.attention import dot_product_attention_weights
|
from flax.linen.attention import dot_product_attention_weights
|
||||||
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.random import PRNGKey
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
@@ -2031,12 +2049,13 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
|||||||
input_shape: Tuple[int] = (1, 1),
|
input_shape: Tuple[int] = (1, 1),
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
dtype: jnp.dtype = jnp.float32,
|
dtype: jnp.dtype = jnp.float32,
|
||||||
|
_do_init: bool = True,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
|
# make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
|
||||||
@@ -2052,7 +2071,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
|||||||
params_rng, dropout_rng = jax.random.split(rng)
|
params_rng, dropout_rng = jax.random.split(rng)
|
||||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||||
|
|
||||||
return self.module.init(
|
random_params = self.module.init(
|
||||||
rngs,
|
rngs,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -2062,6 +2081,16 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
|||||||
decoder_position_ids,
|
decoder_position_ids,
|
||||||
)["params"]
|
)["params"]
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
random_params = flatten_dict(unfreeze(random_params))
|
||||||
|
params = flatten_dict(unfreeze(params))
|
||||||
|
for missing_key in self._missing_keys:
|
||||||
|
params[missing_key] = random_params[missing_key]
|
||||||
|
self._missing_keys = set()
|
||||||
|
return freeze(unflatten_dict(params))
|
||||||
|
else:
|
||||||
|
return random_params
|
||||||
|
|
||||||
def init_cache(self, batch_size, max_length, encoder_outputs):
|
def init_cache(self, batch_size, max_length, encoder_outputs):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import unfreeze
|
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||||
from transformers import (
|
from transformers import (
|
||||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
@@ -904,6 +904,93 @@ class FlaxModelTesterMixin:
|
|||||||
else:
|
else:
|
||||||
_check_attentions_validity(outputs.attentions)
|
_check_attentions_validity(outputs.attentions)
|
||||||
|
|
||||||
|
def test_no_automatic_init(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.return_dict = True
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config, _do_init=False)
|
||||||
|
|
||||||
|
# Check that accesing parmas raises an ValueError when _do_init is False
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
params = model.params
|
||||||
|
|
||||||
|
# Check if we params can be properly initialized when calling init_weights
|
||||||
|
params = model.init_weights(model.key, model.input_shape)
|
||||||
|
self.assertIsInstance(params, FrozenDict)
|
||||||
|
# Check if all required parmas are initialized
|
||||||
|
keys = set(flatten_dict(unfreeze(params)).keys())
|
||||||
|
self.assertTrue(all(k in keys for k in model.required_params))
|
||||||
|
# Check if the shapes match
|
||||||
|
flat_params = flatten_dict(unfreeze(params))
|
||||||
|
for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items():
|
||||||
|
self.assertEqual(
|
||||||
|
v.shape,
|
||||||
|
flat_params[k].shape,
|
||||||
|
"Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that setting params raises an ValueError when _do_init is False
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model.params = params
|
||||||
|
|
||||||
|
# Check if we can do a forward pass
|
||||||
|
inputs_dict["output_hidden_states"] = True
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class).copy()
|
||||||
|
model(**inputs, params=params)
|
||||||
|
|
||||||
|
def test_from_pretrained_with_no_automatic_init(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.return_dict = True
|
||||||
|
|
||||||
|
def _assert_all_params_initialised(model, params):
|
||||||
|
# Check if all required parmas are loaded
|
||||||
|
keys = set(flatten_dict(unfreeze(params)).keys())
|
||||||
|
self.assertTrue(all(k in keys for k in model.required_params))
|
||||||
|
# Check if the shapes match
|
||||||
|
flat_params = flatten_dict(unfreeze(params))
|
||||||
|
for k, v in flatten_dict(unfreeze(model.params_shape_tree)).items():
|
||||||
|
self.assertEqual(
|
||||||
|
v.shape,
|
||||||
|
flat_params[k].shape,
|
||||||
|
"Shapes of {} do not match. Expecting {}, got {}.".format(k, v.shape, flat_params[k].shape),
|
||||||
|
)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
# init the model
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
# save the model in the temporary directory
|
||||||
|
# load the saved model with _do_init=False
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
|
||||||
|
|
||||||
|
# Check that accesing parmas raises an ValueError when _do_init is False
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
params = model.params
|
||||||
|
|
||||||
|
# Check if all required parmas are loaded
|
||||||
|
_assert_all_params_initialised(model, params)
|
||||||
|
|
||||||
|
# Check that setting params raises an ValueError when _do_init is False
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model.params = params
|
||||||
|
|
||||||
|
# Check if init_weights initializes missing keys from from_pretrained
|
||||||
|
flat_params = flatten_dict(unfreeze(params))
|
||||||
|
random_key = random.choice(list(flat_params.keys()))
|
||||||
|
flat_params.pop(random_key)
|
||||||
|
params = freeze(unflatten_dict(flat_params))
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname, params=params)
|
||||||
|
model, params = model_class.from_pretrained(tmpdirname, _do_init=False)
|
||||||
|
|
||||||
|
params = model.init_weights(model.key, model.input_shape, params=params)
|
||||||
|
# Check if all required parmas are loaded
|
||||||
|
_assert_all_params_initialised(model, params)
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user