Default to accuracy metric (#11405)

This commit is contained in:
Sylvain Gugger
2021-04-23 14:49:59 -04:00
committed by GitHub
parent e3ff165aa5
commit 1ef152eb48

View File

@@ -367,6 +367,8 @@ def main():
# Get the metric function # Get the metric function
if args.task_name is not None: if args.task_name is not None:
metric = load_metric("glue", args.task_name) metric = load_metric("glue", args.task_name)
else:
metric = load_metric("accuracy")
# Train! # Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps