Flax Remat for LongT5 (#17994)
* [Flax] Add remat (gradient checkpointing) * fix variable naming in test * flip: checkpoint using a method * fix naming * fix class naming * apply PVP's suggestions from code review * add gradient_checkpointing to examples * Add gradient_checkpointing to run_mlm_flax * Add remat to longt5 * Add gradient checkpointing test longt5 * Fix args errors * Fix remaining tests * Make fixup & quality fixes * replace kwargs * remove unecessary kwargs * Make fixup changes * revert long_t5_flax changes * Remove return_dict and copy to LongT5 * Remove test_gradient_checkpointing Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
This commit is contained in:
@@ -107,6 +107,12 @@ class TrainingArguments:
|
||||
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||
)
|
||||
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||
gradient_checkpointing: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.output_dir is not None:
|
||||
@@ -640,6 +646,9 @@ def main():
|
||||
dtype=getattr(jnp, model_args.dtype),
|
||||
)
|
||||
|
||||
if training_args.gradient_checkpointing:
|
||||
model.enable_gradient_checkpointing()
|
||||
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
|
||||
Reference in New Issue
Block a user