Fix metric calculation in examples and setup tests to run on multi-gpu for no_trainer scripts (#17331)

* Fix length in no_trainer examples

* Add setup and teardown

* Use new accelerator config generator to automatically make tests able to run based on environment
This commit is contained in:
Zachary Mueller
2022-05-18 14:17:40 -04:00
committed by GitHub
parent 6e195eb9de
commit 1762ded30a
8 changed files with 91 additions and 121 deletions

View File

@@ -310,7 +310,9 @@ def parse_args():
def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
if args.source_prefix is None and args.model_name_or_path in [
"t5-small",
"t5-base",
@@ -322,9 +324,6 @@ def main():
"You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
"`--source_prefix 'summarize: ' `"
)
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -675,11 +674,11 @@ def main():
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
# If we are in a multiprocess environment, the last batch has duplicates
if accelerator.num_processes > 1:
if step == len(eval_dataloader):
if step == len(eval_dataloader) - 1:
decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
else:
samples_seen += decoded_labels.shape[0]
samples_seen += len(decoded_labels)
metric.add_batch(
predictions=decoded_preds,