fix flax examples tests (#14646)

* make tensorboard optional

* update test_fetcher for flax examples

* make the tests slow
This commit is contained in:
Suraj Patil
2021-12-07 00:34:27 +05:30
committed by GitHub
parent 03fda7b743
commit 75ae287aec
5 changed files with 74 additions and 15 deletions

View File

@@ -36,7 +36,6 @@ import optax
import transformers
from flax import struct, traverse_util
from flax.jax_utils import replicate, unreplicate
from flax.metrics import tensorboard
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
@@ -46,6 +45,7 @@ from transformers import (
FlaxAutoModelForTokenClassification,
HfArgumentParser,
TrainingArguments,
is_tensorboard_available,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils import check_min_version
@@ -472,8 +472,23 @@ def main():
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# Define a summary writer
summary_writer = tensorboard.SummaryWriter(training_args.output_dir)
summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(training_args.output_dir)
summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
def write_train_metric(summary_writer, train_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
@@ -605,7 +620,7 @@ def main():
# Save metrics
train_metric = unreplicate(train_metric)
train_time += time.time() - train_start
if jax.process_index() == 0:
if has_tensorboard and jax.process_index() == 0:
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
@@ -663,7 +678,7 @@ def main():
f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc: {eval_metrics['accuracy']})"
)
if jax.process_index() == 0:
if has_tensorboard and jax.process_index() == 0:
write_eval_metric(summary_writer, eval_metrics, cur_step)
if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):