Flax Remat for LongT5 (#17994)

* [Flax] Add remat (gradient checkpointing)

* fix variable naming in test

* flip: checkpoint using a method

* fix naming

* fix class naming

* apply PVP's suggestions from code review

* add gradient_checkpointing to examples

* Add gradient_checkpointing to run_mlm_flax

* Add remat to longt5

* Add gradient checkpointing test longt5

* Fix args errors

* Fix remaining tests

* Make fixup & quality fixes

* replace kwargs

* remove unnecessary kwargs

* Make fixup changes

* revert long_t5_flax changes

* Remove return_dict and copy to LongT5

* Remove test_gradient_checkpointing

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
This commit is contained in:
Karim Foda
2022-08-14 17:27:13 +02:00
committed by GitHub
parent 1ccd2515ed
commit d6eeb87170
4 changed files with 149 additions and 39 deletions

View File

@@ -107,6 +107,12 @@ class TrainingArguments:
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
def __post_init__(self):
if self.output_dir is not None:
@@ -640,6 +646,9 @@ def main():
dtype=getattr(jnp, model_args.dtype),
)
if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()

View File

@@ -121,6 +121,12 @@ class TrainingArguments:
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
def __post_init__(self):
if self.output_dir is not None:
@@ -535,6 +541,9 @@ def main():
dtype=getattr(jnp, model_args.dtype),
)
if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")