From bb4ac2b5a8df9d4aaf6cd04e5fd104568a4bb351 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Jul 2021 18:14:50 +0100 Subject: [PATCH] [Flax] Correct flax training scripts (#12514) * fix_torch_device_generate_test * remove @ * add logging steps * correct training scripts * correct readme * correct --- examples/flax/language-modeling/README.md | 4 +- .../flax/language-modeling/run_clm_flax.py | 59 ++++++++++-------- .../flax/language-modeling/run_mlm_flax.py | 61 ++++++++++--------- .../flax/language-modeling/run_t5_mlm_flax.py | 25 +++++--- 4 files changed, 87 insertions(+), 62 deletions(-) diff --git a/examples/flax/language-modeling/README.md b/examples/flax/language-modeling/README.md index 81fdca27e0..48b5613283 100644 --- a/examples/flax/language-modeling/README.md +++ b/examples/flax/language-modeling/README.md @@ -137,10 +137,10 @@ Next we can run the example script to pretrain the model: --learning_rate="3e-4" \ --warmup_steps="1000" \ --overwrite_output_dir \ - --pad_to_max_length \ --num_train_epochs="18" \ --adam_beta1="0.9" \ --adam_beta2="0.98" \ + --logging_steps="500" \ --push_to_hub ``` @@ -233,6 +233,7 @@ Next we can run the example script to pretrain the model: --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \ --overwrite_output_dir \ --num_train_epochs="20" \ + --logging_steps="500" \ --push_to_hub ``` @@ -368,6 +369,7 @@ Next we can run the example script to pretrain the model: --warmup_steps="5000" \ --overwrite_output_dir \ --num_train_epochs="10" \ + --logging_steps="500" \ --push_to_hub ``` diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index e664e5718a..8ade7d4284 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -57,22 +57,6 @@ from transformers.testing_utils import CaptureLogger logger = logging.getLogger(__name__) -# Cache the result -has_tensorboard = is_tensorboard_available() -if has_tensorboard: - try: - from flax.metrics.tensorboard import SummaryWriter - except ImportError as ie: - has_tensorboard = False - print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}") - -else: - print( - "Unable to display metrics through TensorBoard because the package is not installed: " - "Please run pip install tensorboard to enable." - ) - - MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -214,7 +198,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf yield batch -def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): +def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) @@ -223,6 +207,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) + +def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) @@ -450,8 +436,22 @@ def main(): eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: - summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) # Initialize our training rng = jax.random.PRNGKey(training_args.seed) @@ -554,6 +554,7 @@ def main(): logger.info(f" Total optimization steps = {total_train_steps}") train_time = 0 + train_metrics = [] epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) for epoch in epochs: # ======================== Training ================================ @@ -561,24 +562,30 @@ def main(): # Create sampling rng rng, input_rng = jax.random.split(rng) - train_metrics = [] # Generate an epoch by shuffling sampling indices from the train dataset train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) steps_per_epoch = len(train_dataset) // train_batch_size # train - for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): + for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) - train_time += time.time() - train_start + cur_step = epoch * (len(train_dataset) // train_batch_size) + step - train_metric = unreplicate(train_metric) + if cur_step % training_args.logging_steps and cur_step > 0: + # Save metrics + train_metric = unreplicate(train_metric) + train_time += time.time() - train_start + if has_tensorboard and jax.process_index() == 0: + write_train_metric(summary_writer, train_metrics, train_time, cur_step) - epochs.write( - f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" - ) + epochs.write( + f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})" + ) + + train_metrics = [] # ======================== Evaluating ============================== eval_metrics = [] @@ -608,7 +615,7 @@ def main(): # Save metrics if has_tensorboard and jax.process_index() == 0: cur_step = epoch * (len(train_dataset) // train_batch_size) - write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + write_eval_metric(summary_writer, eval_metrics, cur_step) # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index e3058c4ca7..5d9fda11a3 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -56,22 +56,6 @@ from transformers import ( ) -# Cache the result -has_tensorboard = is_tensorboard_available() -if has_tensorboard: - try: - from flax.metrics.tensorboard import SummaryWriter - except ImportError as ie: - has_tensorboard = False - print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}") - -else: - print( - "Unable to display metrics through TensorBoard because the package is not installed: " - "Please run pip install tensorboard to enable." - ) - - MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -269,7 +253,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar return batch_idx -def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): +def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) @@ -278,6 +262,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) + +def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) @@ -315,10 +301,6 @@ if __name__ == "__main__": # Log on each process the small summary: logger = logging.getLogger(__name__) - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") @@ -471,8 +453,22 @@ if __name__ == "__main__": ) # Enable tensorboard only on the master node + has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: - summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) # Data collator # This one will take care of randomly masking the tokens. @@ -601,7 +597,7 @@ if __name__ == "__main__": train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step - for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)): + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)): samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx] model_inputs = data_collator(samples, pad_to_multiple_of=16) @@ -610,11 +606,20 @@ if __name__ == "__main__": state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs) train_metrics.append(train_metric) - train_time += time.time() - train_start + cur_step = epoch * num_train_samples + step - epochs.write( - f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" - ) + if cur_step % training_args.logging_steps and cur_step > 0: + # Save metrics + train_metric = jax_utils.unreplicate(train_metric) + train_time += time.time() - train_start + if has_tensorboard and jax.process_index() == 0: + write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})" + ) + + train_metrics = [] # ======================== Evaluating ============================== num_eval_samples = len(tokenized_datasets["validation"]) @@ -645,7 +650,7 @@ if __name__ == "__main__": # Save metrics if has_tensorboard and jax.process_index() == 0: cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) - write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + write_eval_metric(summary_writer, eval_metrics, cur_step) # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 49f4cf1d79..795dc7faeb 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -382,7 +382,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar return batch_idx -def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): +def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) @@ -391,6 +391,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) + +def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) @@ -711,7 +713,7 @@ if __name__ == "__main__": train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step - for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)): + for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)): samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx] model_inputs = data_collator(samples) @@ -720,11 +722,20 @@ if __name__ == "__main__": state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs) train_metrics.append(train_metric) - train_time += time.time() - train_start + cur_step = epoch * num_train_samples + step - epochs.write( - f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})" - ) + if cur_step % training_args.logging_steps and cur_step > 0: + # Save metrics + train_metric = jax_utils.unreplicate(train_metric) + train_time += time.time() - train_start + if has_tensorboard and jax.process_index() == 0: + write_train_metric(summary_writer, train_metrics, train_time, cur_step) + + epochs.write( + f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})" + ) + + train_metrics = [] # ======================== Evaluating ============================== num_eval_samples = len(tokenized_datasets["validation"]) @@ -753,7 +764,7 @@ if __name__ == "__main__": # Save metrics if has_tensorboard and jax.process_index() == 0: cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) - write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) + write_eval_metric(summary_writer, eval_metrics, cur_step) # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: