[Flax] Add remat (gradient checkpointing) (#17843)
* [Flax] Add remat (gradient checkpointing) * fix variable naming in test * flip: checkpoint using a method * fix naming * fix class naming * apply PVP's suggestions from code review * make fix-copies * fix big-bird, electra, roberta * cookie-cutter * fix flax big-bird * move test to common
This commit is contained in:
@@ -235,6 +235,9 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
|
||||
raise NotImplementedError(f"init method has to be implemented for {self}")
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -23,6 +23,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen import partitioning as nn_partitioning
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
@@ -56,6 +57,8 @@ _CHECKPOINT_FOR_DOC = "bert-base-uncased"
|
||||
_CONFIG_FOR_DOC = "BertConfig"
|
||||
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||
|
||||
remat = nn_partitioning.remat
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxBertForPreTrainingOutput(ModelOutput):
|
||||
@@ -544,11 +547,19 @@ class FlaxBertLayer(nn.Module):
|
||||
class FlaxBertLayerCollection(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layers = [
|
||||
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
if self.gradient_checkpointing:
|
||||
FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7))
|
||||
self.layers = [
|
||||
FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
|
||||
for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
else:
|
||||
self.layers = [
|
||||
FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -582,12 +593,12 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
layer_outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
init_cache,
|
||||
deterministic,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@@ -617,9 +628,14 @@ class FlaxBertLayerCollection(nn.Module):
|
||||
class FlaxBertEncoder(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)
|
||||
self.layer = FlaxBertLayerCollection(
|
||||
self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -756,11 +772,24 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
gradient_checkpointing: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
module = self.module_class(
|
||||
config=config,
|
||||
dtype=dtype,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
**kwargs,
|
||||
)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self._module = self.module_class(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
@@ -925,10 +954,15 @@ class FlaxBertModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
add_pooling_layer: bool = True
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxBertEncoder(
|
||||
self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1003,9 +1037,14 @@ append_call_sample_docstring(
|
||||
class FlaxBertForPreTrainingModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
|
||||
self.bert = FlaxBertModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1099,9 +1138,15 @@ append_replace_return_docstrings(
|
||||
class FlaxBertForMaskedLMModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.bert = FlaxBertModule(
|
||||
config=self.config,
|
||||
add_pooling_layer=False,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1161,9 +1206,14 @@ append_call_sample_docstring(
|
||||
class FlaxBertForNextSentencePredictionModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
|
||||
self.bert = FlaxBertModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1248,9 +1298,14 @@ append_replace_return_docstrings(
|
||||
class FlaxBertForSequenceClassificationModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
|
||||
self.bert = FlaxBertModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
classifier_dropout = (
|
||||
self.config.classifier_dropout
|
||||
if self.config.classifier_dropout is not None
|
||||
@@ -1324,9 +1379,14 @@ append_call_sample_docstring(
|
||||
class FlaxBertForMultipleChoiceModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
|
||||
self.bert = FlaxBertModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(1, dtype=self.dtype)
|
||||
|
||||
@@ -1399,9 +1459,15 @@ append_call_sample_docstring(
|
||||
class FlaxBertForTokenClassificationModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.bert = FlaxBertModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
add_pooling_layer=False,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
classifier_dropout = (
|
||||
self.config.classifier_dropout
|
||||
if self.config.classifier_dropout is not None
|
||||
@@ -1468,9 +1534,15 @@ append_call_sample_docstring(
|
||||
class FlaxBertForQuestionAnsweringModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.bert = FlaxBertModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
add_pooling_layer=False,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1539,9 +1611,15 @@ append_call_sample_docstring(
|
||||
class FlaxBertForCausalLMModule(nn.Module):
|
||||
config: BertConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.bert = FlaxBertModule(
|
||||
config=self.config,
|
||||
add_pooling_layer=False,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -23,6 +23,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen import partitioning as nn_partitioning
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
@@ -54,6 +55,8 @@ _CHECKPOINT_FOR_DOC = "google/bigbird-roberta-base"
|
||||
_CONFIG_FOR_DOC = "BigBirdConfig"
|
||||
_TOKENIZER_FOR_DOC = "BigBirdTokenizer"
|
||||
|
||||
remat = nn_partitioning.remat
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxBigBirdForPreTrainingOutput(ModelOutput):
|
||||
@@ -1368,12 +1371,20 @@ class FlaxBigBirdLayer(nn.Module):
|
||||
class FlaxBigBirdLayerCollection(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layers = [
|
||||
FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
|
||||
for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
if self.gradient_checkpointing:
|
||||
FlaxBigBirdCheckpointLayer = remat(FlaxBigBirdLayer, static_argnums=(5, 6, 7))
|
||||
self.layers = [
|
||||
FlaxBigBirdCheckpointLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
|
||||
for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
else:
|
||||
self.layers = [
|
||||
FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
|
||||
for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection.__call__ with Bert->BigBird
|
||||
def __call__(
|
||||
@@ -1408,12 +1419,12 @@ class FlaxBigBirdLayerCollection(nn.Module):
|
||||
layer_outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
init_cache,
|
||||
deterministic,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@@ -1444,9 +1455,14 @@ class FlaxBigBirdLayerCollection(nn.Module):
|
||||
class FlaxBigBirdEncoder(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layer = FlaxBigBirdLayerCollection(self.config, dtype=self.dtype)
|
||||
self.layer = FlaxBigBirdLayerCollection(
|
||||
self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -1559,9 +1575,10 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
gradient_checkpointing: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
|
||||
if config.attention_type == "block_sparse" and input_shape is None:
|
||||
input_shape = (1, 12 * config.block_size)
|
||||
elif input_shape is None:
|
||||
@@ -1569,6 +1586,14 @@ class FlaxBigBirdPreTrainedModel(FlaxPreTrainedModel):
|
||||
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
|
||||
def enable_gradient_checkpointing(self):
|
||||
self._module = self.module_class(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
@@ -1735,10 +1760,13 @@ class FlaxBigBirdModule(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
add_pooling_layer: bool = True
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.embeddings = FlaxBigBirdEmbeddings(self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxBigBirdEncoder(self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxBigBirdEncoder(
|
||||
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
self.pooler = nn.Dense(
|
||||
self.config.hidden_size,
|
||||
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
|
||||
@@ -1812,9 +1840,14 @@ append_call_sample_docstring(
|
||||
class FlaxBigBirdForPreTrainingModule(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
|
||||
self.bert = FlaxBigBirdModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.cls = FlaxBigBirdPreTrainingHeads(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1910,9 +1943,15 @@ append_replace_return_docstrings(
|
||||
class FlaxBigBirdForMaskedLMModule(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.bert = FlaxBigBirdModule(
|
||||
config=self.config,
|
||||
add_pooling_layer=False,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1999,9 +2038,12 @@ class FlaxBigBirdClassificationHead(nn.Module):
|
||||
class FlaxBigBirdForSequenceClassificationModule(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
|
||||
self.bert = FlaxBigBirdModule(
|
||||
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
self.classifier = FlaxBigBirdClassificationHead(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -2067,9 +2109,14 @@ append_call_sample_docstring(
|
||||
class FlaxBigBirdForMultipleChoiceModule(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype)
|
||||
self.bert = FlaxBigBirdModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(1, dtype=self.dtype)
|
||||
|
||||
@@ -2162,9 +2209,15 @@ append_call_sample_docstring(
|
||||
class FlaxBigBirdForTokenClassificationModule(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.bert = FlaxBigBirdModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
add_pooling_layer=False,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
classifier_dropout = (
|
||||
self.config.classifier_dropout
|
||||
if self.config.classifier_dropout is not None
|
||||
@@ -2255,10 +2308,16 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
add_pooling_layer: bool = False
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.config.num_labels = 2
|
||||
self.bert = FlaxBigBirdModule(self.config, dtype=self.dtype, add_pooling_layer=self.add_pooling_layer)
|
||||
self.bert = FlaxBigBirdModule(
|
||||
self.config,
|
||||
dtype=self.dtype,
|
||||
add_pooling_layer=self.add_pooling_layer,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.qa_classifier = FlaxBigBirdForQuestionAnsweringHead(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -2414,9 +2473,15 @@ append_call_sample_docstring(
|
||||
class FlaxBigBirdForCausalLMModule(nn.Module):
|
||||
config: BigBirdConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.bert = FlaxBigBirdModule(
|
||||
config=self.config,
|
||||
add_pooling_layer=False,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -23,6 +23,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen import partitioning as nn_partitioning
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
@@ -54,6 +55,8 @@ _CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
|
||||
_CONFIG_FOR_DOC = "ElectraConfig"
|
||||
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
|
||||
|
||||
remat = nn_partitioning.remat
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxElectraForPreTrainingOutput(ModelOutput):
|
||||
@@ -521,11 +524,20 @@ class FlaxElectraLayer(nn.Module):
|
||||
class FlaxElectraLayerCollection(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layers = [
|
||||
FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
if self.gradient_checkpointing:
|
||||
FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7))
|
||||
self.layers = [
|
||||
FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
|
||||
for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
else:
|
||||
self.layers = [
|
||||
FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype)
|
||||
for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -559,12 +571,12 @@ class FlaxElectraLayerCollection(nn.Module):
|
||||
layer_outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
init_cache,
|
||||
deterministic,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@@ -595,9 +607,14 @@ class FlaxElectraLayerCollection(nn.Module):
|
||||
class FlaxElectraEncoder(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layer = FlaxElectraLayerCollection(self.config, dtype=self.dtype)
|
||||
self.layer = FlaxElectraLayerCollection(
|
||||
self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -675,11 +692,20 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
gradient_checkpointing: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
|
||||
def enable_gradient_checkpointing(self):
|
||||
self._module = self.module_class(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
@@ -845,12 +871,15 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
|
||||
class FlaxElectraModule(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
|
||||
if self.config.embedding_size != self.config.hidden_size:
|
||||
self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
|
||||
self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxElectraEncoder(
|
||||
self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -925,9 +954,12 @@ class FlaxElectraTiedDense(nn.Module):
|
||||
class FlaxElectraForMaskedLMModule(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
|
||||
self.electra = FlaxElectraModule(
|
||||
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
|
||||
@@ -989,9 +1021,12 @@ append_call_sample_docstring(
|
||||
class FlaxElectraForPreTrainingModule(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
|
||||
self.electra = FlaxElectraModule(
|
||||
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1074,9 +1109,12 @@ append_replace_return_docstrings(
|
||||
class FlaxElectraForTokenClassificationModule(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
|
||||
self.electra = FlaxElectraModule(
|
||||
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
classifier_dropout = (
|
||||
self.config.classifier_dropout
|
||||
if self.config.classifier_dropout is not None
|
||||
@@ -1218,9 +1256,12 @@ class FlaxElectraSequenceSummary(nn.Module):
|
||||
class FlaxElectraForMultipleChoiceModule(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
|
||||
self.electra = FlaxElectraModule(
|
||||
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype)
|
||||
self.classifier = nn.Dense(1, dtype=self.dtype)
|
||||
|
||||
@@ -1297,9 +1338,12 @@ append_call_sample_docstring(
|
||||
class FlaxElectraForQuestionAnsweringModule(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
|
||||
self.electra = FlaxElectraModule(
|
||||
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1392,9 +1436,12 @@ class FlaxElectraClassificationHead(nn.Module):
|
||||
class FlaxElectraForSequenceClassificationModule(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
|
||||
self.electra = FlaxElectraModule(
|
||||
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1457,9 +1504,12 @@ append_call_sample_docstring(
|
||||
class FlaxElectraForCausalLMModule(nn.Module):
|
||||
config: ElectraConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
|
||||
self.electra = FlaxElectraModule(
|
||||
config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
|
||||
)
|
||||
self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
|
||||
|
||||
@@ -21,6 +21,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen import partitioning as nn_partitioning
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
@@ -47,6 +48,8 @@ _CHECKPOINT_FOR_DOC = "roberta-base"
|
||||
_CONFIG_FOR_DOC = "RobertaConfig"
|
||||
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
||||
|
||||
remat = nn_partitioning.remat
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||
"""
|
||||
@@ -511,11 +514,20 @@ class FlaxRobertaLayer(nn.Module):
|
||||
class FlaxRobertaLayerCollection(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layers = [
|
||||
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
if self.gradient_checkpointing:
|
||||
FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7))
|
||||
self.layers = [
|
||||
FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
|
||||
for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
else:
|
||||
self.layers = [
|
||||
FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype)
|
||||
for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -549,12 +561,12 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
layer_outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
init_cache,
|
||||
deterministic,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@@ -585,9 +597,14 @@ class FlaxRobertaLayerCollection(nn.Module):
|
||||
class FlaxRobertaEncoder(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype)
|
||||
self.layer = FlaxRobertaLayerCollection(
|
||||
self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -719,11 +736,20 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
gradient_checkpointing: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
|
||||
def enable_gradient_checkpointing(self):
|
||||
self._module = self.module_class(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
@@ -889,10 +915,15 @@ class FlaxRobertaModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
add_pooling_layer: bool = True
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype)
|
||||
self.encoder = FlaxRobertaEncoder(
|
||||
self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -967,9 +998,15 @@ append_call_sample_docstring(
|
||||
class FlaxRobertaForMaskedLMModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.roberta = FlaxRobertaModule(
|
||||
config=self.config,
|
||||
add_pooling_layer=False,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1034,9 +1071,15 @@ append_call_sample_docstring(
|
||||
class FlaxRobertaForSequenceClassificationModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.roberta = FlaxRobertaModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
add_pooling_layer=False,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1101,9 +1144,14 @@ append_call_sample_docstring(
|
||||
class FlaxRobertaForMultipleChoiceModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype)
|
||||
self.roberta = FlaxRobertaModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(1, dtype=self.dtype)
|
||||
|
||||
@@ -1181,9 +1229,15 @@ append_call_sample_docstring(
|
||||
class FlaxRobertaForTokenClassificationModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.roberta = FlaxRobertaModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
add_pooling_layer=False,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
classifier_dropout = (
|
||||
self.config.classifier_dropout
|
||||
if self.config.classifier_dropout is not None
|
||||
@@ -1255,9 +1309,15 @@ append_call_sample_docstring(
|
||||
class FlaxRobertaForQuestionAnsweringModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.roberta = FlaxRobertaModule(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
add_pooling_layer=False,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1326,9 +1386,15 @@ append_call_sample_docstring(
|
||||
class FlaxRobertaForCausalLMModule(nn.Module):
|
||||
config: RobertaConfig
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.roberta = FlaxRobertaModule(
|
||||
config=self.config,
|
||||
add_pooling_layer=False,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=self.gradient_checkpointing,
|
||||
)
|
||||
self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -25,6 +25,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen import partitioning as nn_partitioning
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from jax import lax
|
||||
@@ -126,6 +127,8 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
|
||||
"""
|
||||
|
||||
remat = nn_partitioning.remat
|
||||
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
@@ -507,11 +510,19 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
|
||||
class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layers = [
|
||||
Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
if self.gradient_checkpointing:
|
||||
Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer = remat(Flax{{cookiecutter.camelcase_modelname}}Layer, static_argnums=(5, 6, 7))
|
||||
self.layers = [
|
||||
Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer(self.config, name=str(i), dtype=self.dtype)
|
||||
for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
else:
|
||||
self.layers = [
|
||||
Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -545,12 +556,12 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
layer_outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
init_cache,
|
||||
deterministic,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@@ -581,9 +592,10 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype)
|
||||
self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -725,11 +737,20 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
gradient_checkpointing: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
|
||||
def enable_gradient_checkpointing(self):
|
||||
self._module = self.module_class(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
@@ -897,10 +918,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
add_pooling_layer: bool = True
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.embeddings = Flax{{cookiecutter.camelcase_modelname}}Embeddings(self.config, dtype=self.dtype)
|
||||
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype)
|
||||
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.pooler = Flax{{cookiecutter.camelcase_modelname}}Pooler(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -969,9 +991,10 @@ class Flax{{cookiecutter.camelcase_modelname}}Model(Flax{{cookiecutter.camelcase
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1030,9 +1053,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1092,9 +1116,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(
|
||||
self.config.num_labels,
|
||||
@@ -1163,9 +1188,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(1, dtype=self.dtype)
|
||||
|
||||
@@ -1238,9 +1264,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
@@ -1302,9 +1329,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1373,9 +1401,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
|
||||
@@ -1099,6 +1099,33 @@ class FlaxModelTesterMixin:
|
||||
for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()):
|
||||
self.assertTrue(np.allclose(np.array(p1), np.array(p2)))
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
remat_model = model_class(config)
|
||||
|
||||
try:
|
||||
remat_model.enable_gradient_checkpointing()
|
||||
except NotImplementedError:
|
||||
continue
|
||||
|
||||
outputs = model(**prepared_inputs_dict)
|
||||
remat_outputs = remat_model(**prepared_inputs_dict)
|
||||
|
||||
# ensure that the dicts of outputs contain the same keys
|
||||
self.assertEqual(outputs.keys(), remat_outputs.keys())
|
||||
|
||||
outputs = outputs.to_tuple()
|
||||
remat_outputs = remat_outputs.to_tuple()
|
||||
|
||||
# ensure that the outputs remain precisely equal
|
||||
for output, remat_output in zip(outputs, remat_outputs):
|
||||
self.assertTrue((output == remat_output).all())
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user