Update no_trainer script for summarization (#19277)
* Update no_trainer script for summarization * removed unnecessary import * fixes notation mistake * removed: unused variable
This commit is contained in:
@@ -669,7 +669,6 @@ def main():
|
|||||||
"max_length": args.val_max_target_length if args is not None else config.max_length,
|
"max_length": args.val_max_target_length if args is not None else config.max_length,
|
||||||
"num_beams": args.num_beams,
|
"num_beams": args.num_beams,
|
||||||
}
|
}
|
||||||
samples_seen = 0
|
|
||||||
for step, batch in enumerate(eval_dataloader):
|
for step, batch in enumerate(eval_dataloader):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
generated_tokens = accelerator.unwrap_model(model).generate(
|
generated_tokens = accelerator.unwrap_model(model).generate(
|
||||||
@@ -686,7 +685,7 @@ def main():
|
|||||||
# If we did not pad to max length, we need to pad the labels too
|
# If we did not pad to max length, we need to pad the labels too
|
||||||
labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
|
labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
|
||||||
|
|
||||||
generated_tokens, labels = accelerator.gather((generated_tokens, labels))
|
generated_tokens, labels = accelerator.gather_for_metrics(generated_tokens, labels)
|
||||||
generated_tokens = generated_tokens.cpu().numpy()
|
generated_tokens = generated_tokens.cpu().numpy()
|
||||||
labels = labels.cpu().numpy()
|
labels = labels.cpu().numpy()
|
||||||
|
|
||||||
@@ -699,14 +698,8 @@ def main():
|
|||||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||||
|
|
||||||
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
||||||
# If we are in a multiprocess environment, the last batch has duplicates
|
|
||||||
if accelerator.num_processes > 1:
|
|
||||||
if step == len(eval_dataloader) - 1:
|
|
||||||
decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
|
|
||||||
decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
|
|
||||||
else:
|
|
||||||
samples_seen += len(decoded_labels)
|
|
||||||
|
|
||||||
|
decoded_preds, decoded_labels = accelerator.gather_for_metrics(decoded_preds, decoded_labels)
|
||||||
metric.add_batch(
|
metric.add_batch(
|
||||||
predictions=decoded_preds,
|
predictions=decoded_preds,
|
||||||
references=decoded_labels,
|
references=decoded_labels,
|
||||||
|
|||||||
Reference in New Issue
Block a user