From 7544efc92e70dba40c8fde43fe5a9b156c375810 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Carlos=20Falc=C3=A3o=20Petri?= Date: Wed, 17 Nov 2021 14:37:21 -0300 Subject: [PATCH] [Gradient checkpoining] Update Wav2Vec scripts (#14036) Co-authored-by: Stas Bekman --- .../jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py index 7eb286b496..4911ecb571 100755 --- a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py +++ b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py @@ -48,9 +48,6 @@ class ModelArguments: freeze_feature_extractor: Optional[bool] = field( default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."} ) - gradient_checkpointing: Optional[bool] = field( - default=False, metadata={"help": "Whether to freeze the feature extractor layers of the model."} - ) verbose_logging: Optional[bool] = field( default=False, metadata={"help": "Whether to log verbose messages or not."}, @@ -356,7 +353,6 @@ def main(): config = Wav2Vec2Config.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, - gradient_checkpointing=model_args.gradient_checkpointing, ) if not config.do_stable_layer_norm or config.feat_extract_norm != "layer": @@ -366,6 +362,10 @@ def main(): model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) + # Activate gradient checkpointing if needed + if training_args.gradient_checkpointing: + model.gradient_checkpointing_enable() + data_collator = FlaxDataCollatorForWav2Vec2Pretraining( model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of )