From 7d6285a921a23c06169e2d90c94faa0d92d00d78 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jul 2021 23:49:47 +0100 Subject: [PATCH] [Wav2Vec2] Flax - Adapt wav2vec2 script (#12520) * fix_torch_device_generate_test * remove @ * adapt flax pretrain script --- .../wav2vec2/run_wav2vec2_pretrain_flax.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py index d0e60b8def..a0a7d38f85 100755 --- a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py +++ b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py @@ -64,6 +64,12 @@ class ModelArguments: gumbel_temperature_decay: Optional[float] = field( default=0.999995, metadata={"help": "Decay of gumbel temperature during training."} ) + dtype: Optional[str] = field( + default="float32", + metadata={ + "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`." + }, + ) @flax.struct.dataclass @@ -197,7 +203,7 @@ def configure_logger(model_args: ModelArguments, training_args: TrainingArgument logger.setLevel(logging_level) -def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): +def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) @@ -206,6 +212,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) + +def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) @@ -342,9 +350,7 @@ def main(): "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'" ) - model = FlaxWav2Vec2ForPreTraining( - config, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype) - ) + model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)) data_collator = FlaxDataCollatorForWav2Vec2Pretraining( model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of @@ -501,11 +507,11 @@ def main(): state = jax_utils.replicate(state) train_time = 0 + train_metrics = [] epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) for epoch in epochs: # ======================== Training ================================ train_start = time.time() - train_metrics = [] # Create sampling rng rng, input_rng = jax.random.split(rng) @@ -516,7 +522,7 @@ def main(): train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step - for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)): + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)): samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx] model_inputs = data_collator(samples) model_inputs = shard(model_inputs.data) @@ -527,11 +533,20 @@ def main(): ) train_metrics.append(train_metric) - train_time += time.time() - train_start + cur_step = epoch * (num_train_samples // train_batch_size) + step - epochs.write( - f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})" - ) + if cur_step % training_args.logging_steps == 0 and cur_step > 0: + # Save metrics + train_metric = jax_utils.unreplicate(train_metric) + train_time += time.time() - train_start + if has_tensorboard and jax.process_index() == 0: + write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})" + ) + + train_metrics = [] # ======================== Evaluating ============================== num_eval_samples = len(vectorized_datasets["validation"]) @@ -560,7 +575,7 @@ def main(): # Save metrics if has_tensorboard and jax.process_index() == 0: cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size) - write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + write_eval_metric(summary_writer, eval_metrics, cur_step) # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: