Fixup no_trainer examples scripts and add more tests (#16765)

* Change tracking to store_true

* Remove step param and use it in the log dictionary directly

* use vars(args) when passing args to init_trackers

* Include tracking tests since tensorboard is already a dep
This commit is contained in:
Zachary Mueller
2022-04-13 14:40:48 -04:00
committed by GitHub
parent 3a16ab25c8
commit be752d12f8
10 changed files with 84 additions and 63 deletions

View File

@@ -219,7 +219,7 @@ def parse_args():
)
parser.add_argument(
"--with_tracking",
required=False,
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
)
args = parser.parse_args()
@@ -246,7 +246,7 @@ def main():
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all") if args.with_tracking else Accelerator()
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -545,7 +545,10 @@ def main():
# We need to initialize the trackers we use, and also store our configuration
if args.with_tracking:
accelerator.init_trackers("clm_no_trainer", args)
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("ner_no_trainer", experiment_config)
# Metrics
metric = load_metric("seqeval")
@@ -676,12 +679,7 @@ def main():
accelerator.print(f"epoch {epoch}:", eval_metric)
if args.with_tracking:
accelerator.log(
{
"seqeval": eval_metric,
"train_loss": total_loss,
"epoch": epoch,
},
step=completed_steps,
{"seqeval": eval_metric, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
)
if args.push_to_hub and epoch < args.num_train_epochs - 1: