From 75ae287aecf20a37c232a41e25443a3421a8b5e2 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 7 Dec 2021 00:34:27 +0530 Subject: [PATCH] fix flax examples tests (#14646) * make tensorboard optional * update test_fetcher for flax examples * make the tests slow --- examples/flax/question-answering/run_qa.py | 25 +++++++++++--- examples/flax/test_examples.py | 4 +++ .../flax/text-classification/run_flax_glue.py | 33 ++++++++++++++++--- .../flax/token-classification/run_flax_ner.py | 25 +++++++++++--- utils/tests_fetcher.py | 2 ++ 5 files changed, 74 insertions(+), 15 deletions(-) diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index 38d8229966..0e7f1ad26c 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -40,7 +40,6 @@ import optax import transformers from flax import struct, traverse_util from flax.jax_utils import replicate, unreplicate -from flax.metrics import tensorboard from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard from huggingface_hub import Repository @@ -52,6 +51,7 @@ from transformers import ( HfArgumentParser, PreTrainedTokenizerFast, TrainingArguments, + is_tensorboard_available, ) from transformers.file_utils import get_full_repo_name from transformers.utils import check_min_version @@ -716,8 +716,23 @@ def main(): logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # Define a summary writer - summary_writer = tensorboard.SummaryWriter(training_args.output_dir) - summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(training_args.output_dir) + summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) @@ -833,7 +848,7 @@ def main(): # Save metrics train_metric = unreplicate(train_metric) train_time += time.time() - train_start - if jax.process_index() == 0: + if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write( @@ -898,7 +913,7 @@ def main(): logger.info(f"Step... ({cur_step}/{total_steps} | Evaluation metrics: {eval_metrics})") - if jax.process_index() == 0: + if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metrics, cur_step) if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps): diff --git a/examples/flax/test_examples.py b/examples/flax/test_examples.py index f46e3ac75f..2f6f83cc84 100644 --- a/examples/flax/test_examples.py +++ b/examples/flax/test_examples.py @@ -96,6 +96,7 @@ class ExamplesTests(TestCasePlus): result = get_results(tmp_dir) self.assertGreaterEqual(result["eval_accuracy"], 0.75) + @slow def test_run_clm(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) @@ -155,6 +156,7 @@ class ExamplesTests(TestCasePlus): self.assertGreaterEqual(result["test_rougeL"], 7) self.assertGreaterEqual(result["test_rougeLsum"], 7) + @slow def test_run_mlm(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) @@ -208,6 +210,7 @@ class ExamplesTests(TestCasePlus): result = get_results(tmp_dir) self.assertGreaterEqual(result["eval_accuracy"], 0.42) + @slow def test_run_ner(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) @@ -240,6 +243,7 @@ class ExamplesTests(TestCasePlus): self.assertGreaterEqual(result["eval_accuracy"], 0.75) self.assertGreaterEqual(result["eval_f1"], 0.3) + @slow def test_run_qa(self): stream_handler = logging.StreamHandler(sys.stdout) logger.addHandler(stream_handler) diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index 9044331db5..f27b7cd05c 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -33,11 +33,16 @@ import optax import transformers from flax import struct, traverse_util from flax.jax_utils import replicate, unreplicate -from flax.metrics import tensorboard from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard from huggingface_hub import Repository -from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig +from transformers import ( + AutoConfig, + AutoTokenizer, + FlaxAutoModelForSequenceClassification, + PretrainedConfig, + is_tensorboard_available, +) from transformers.file_utils import get_full_repo_name @@ -404,8 +409,23 @@ def main(): logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # Define a summary writer - summary_writer = tensorboard.SummaryWriter(args.output_dir) - summary_writer.hparams(vars(args)) + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(args.output_dir) + summary_writer.hparams(vars(args)) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) def write_metric(train_metrics, eval_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) @@ -513,7 +533,10 @@ def main(): logger.info(f" Done! Eval metrics: {eval_metric}") cur_step = epoch * (len(train_dataset) // train_batch_size) - write_metric(train_metrics, eval_metric, train_time, cur_step) + + # Save metrics + if has_tensorboard and jax.process_index() == 0: + write_metric(train_metrics, eval_metric, train_time, cur_step) # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index 3a49ff1fc3..9109c04692 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -36,7 +36,6 @@ import optax import transformers from flax import struct, traverse_util from flax.jax_utils import replicate, unreplicate -from flax.metrics import tensorboard from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard from huggingface_hub import Repository @@ -46,6 +45,7 @@ from transformers import ( FlaxAutoModelForTokenClassification, HfArgumentParser, TrainingArguments, + is_tensorboard_available, ) from transformers.file_utils import get_full_repo_name from transformers.utils import check_min_version @@ -472,8 +472,23 @@ def main(): logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # Define a summary writer - summary_writer = tensorboard.SummaryWriter(training_args.output_dir) - summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) + has_tensorboard = is_tensorboard_available() + if has_tensorboard and jax.process_index() == 0: + try: + from flax.metrics.tensorboard import SummaryWriter + + summary_writer = SummaryWriter(training_args.output_dir) + summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) + except ImportError as ie: + has_tensorboard = False + logger.warning( + f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" + ) + else: + logger.warning( + "Unable to display metrics through TensorBoard because the package is not installed: " + "Please run pip install tensorboard to enable." + ) def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) @@ -605,7 +620,7 @@ def main(): # Save metrics train_metric = unreplicate(train_metric) train_time += time.time() - train_start - if jax.process_index() == 0: + if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write( @@ -663,7 +678,7 @@ def main(): f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})" ) - if jax.process_index() == 0: + if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metrics, cur_step) if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps): diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index dbfd77b62e..6b99db2305 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -431,6 +431,8 @@ def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None): # Example files are tested separately elif f.startswith("examples/pytorch"): test_files_to_run.append("examples/pytorch/test_examples.py") + elif f.startswith("examples/flax"): + test_files_to_run.append("examples/flax/test_examples.py") else: new_tests = module_to_test_file(f) if new_tests is not None: