fix flax examples tests (#14646)

* make tensorboard optional

* update test_fetcher for flax examples

* make the tests slow
This commit is contained in:
Suraj Patil
2021-12-07 00:34:27 +05:30
committed by GitHub
parent 03fda7b743
commit 75ae287aec
5 changed files with 74 additions and 15 deletions

View File

@@ -33,11 +33,16 @@ import optax
import transformers
from flax import struct, traverse_util
from flax.jax_utils import replicate, unreplicate
from flax.metrics import tensorboard
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
from transformers import (
AutoConfig,
AutoTokenizer,
FlaxAutoModelForSequenceClassification,
PretrainedConfig,
is_tensorboard_available,
)
from transformers.file_utils import get_full_repo_name
@@ -404,8 +409,23 @@ def main():
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# Define a summary writer
summary_writer = tensorboard.SummaryWriter(args.output_dir)
summary_writer.hparams(vars(args))
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(args.output_dir)
summary_writer.hparams(vars(args))
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
def write_metric(train_metrics, eval_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
@@ -513,7 +533,10 @@ def main():
logger.info(f" Done! Eval metrics: {eval_metric}")
cur_step = epoch * (len(train_dataset) // train_batch_size)
write_metric(train_metrics, eval_metric, train_time, cur_step)
# Save metrics
if has_tensorboard and jax.process_index() == 0:
write_metric(train_metrics, eval_metric, train_time, cur_step)
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0: