Fix metric calculation in examples and setup tests to run on multi-gpu for no_trainer scripts (#17331)
* Fix length in no_trainer examples * Add setup and teardown * Use new accelerator config generator to automatically make tests able to run based on environment
This commit is contained in:
@@ -310,7 +310,9 @@ def parse_args():
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
|
||||
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
|
||||
if args.source_prefix is None and args.model_name_or_path in [
|
||||
"t5-small",
|
||||
"t5-base",
|
||||
@@ -322,9 +324,6 @@ def main():
|
||||
"You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
|
||||
"`--source_prefix 'summarize: ' `"
|
||||
)
|
||||
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
|
||||
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -675,11 +674,11 @@ def main():
|
||||
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
||||
# If we are in a multiprocess environment, the last batch has duplicates
|
||||
if accelerator.num_processes > 1:
|
||||
if step == len(eval_dataloader):
|
||||
if step == len(eval_dataloader) - 1:
|
||||
decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
|
||||
decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
|
||||
else:
|
||||
samples_seen += decoded_labels.shape[0]
|
||||
samples_seen += len(decoded_labels)
|
||||
|
||||
metric.add_batch(
|
||||
predictions=decoded_preds,
|
||||
|
||||
Reference in New Issue
Block a user