[Flax] Add remat (gradient checkpointing) (#17843)
* [Flax] Add remat (gradient checkpointing)
* fix variable naming in test
* flip: checkpoint using a method
* fix naming
* fix class naming
* apply PVP's suggestions from code review
* make fix-copies
* fix big-bird, electra, roberta
* cookie-cutter
* fix flax big-bird
* move test to common
This commit is contained in:
@@ -25,6 +25,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze, freeze
|
||||
from flax.linen import combine_masks, make_causal_mask
|
||||
from flax.linen import partitioning as nn_partitioning
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from flax.linen.attention import dot_product_attention_weights
|
||||
from jax import lax
|
||||
@@ -126,6 +127,8 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
|
||||
"""
|
||||
|
||||
remat = nn_partitioning.remat
|
||||
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
@@ -507,11 +510,19 @@ class Flax{{cookiecutter.camelcase_modelname}}Layer(nn.Module):
|
||||
class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layers = [
|
||||
Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
if self.gradient_checkpointing:
|
||||
Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer = remat(Flax{{cookiecutter.camelcase_modelname}}Layer, static_argnums=(5, 6, 7))
|
||||
self.layers = [
|
||||
Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer(self.config, name=str(i), dtype=self.dtype)
|
||||
for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
else:
|
||||
self.layers = [
|
||||
Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -545,12 +556,12 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
layer_outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask=head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
init_cache=init_cache,
|
||||
deterministic=deterministic,
|
||||
output_attentions=output_attentions,
|
||||
head_mask[i] if head_mask is not None else None,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
init_cache,
|
||||
deterministic,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@@ -581,9 +592,10 @@ class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module):
|
||||
class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype)
|
||||
self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -725,11 +737,20 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
gradient_checkpointing: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
|
||||
def enable_gradient_checkpointing(self):
|
||||
self._module = self.module_class(
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
@@ -897,10 +918,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
add_pooling_layer: bool = True
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.embeddings = Flax{{cookiecutter.camelcase_modelname}}Embeddings(self.config, dtype=self.dtype)
|
||||
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype)
|
||||
self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.pooler = Flax{{cookiecutter.camelcase_modelname}}Pooler(self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -969,9 +991,10 @@ class Flax{{cookiecutter.camelcase_modelname}}Model(Flax{{cookiecutter.camelcase
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1030,9 +1053,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1092,9 +1116,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(
|
||||
self.config.num_labels,
|
||||
@@ -1163,9 +1188,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(1, dtype=self.dtype)
|
||||
|
||||
@@ -1238,9 +1264,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
||||
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
@@ -1302,9 +1329,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
@@ -1373,9 +1401,10 @@ append_call_sample_docstring(
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
|
||||
config: {{cookiecutter.camelcase_modelname}}Config
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
gradient_checkpointing: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype)
|
||||
self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing)
|
||||
self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype)
|
||||
|
||||
def __call__(
|
||||
|
||||
Reference in New Issue
Block a user