Replace accelerator.use_fp16 in examples (#33513)

* Replace `accelerator.use_fp16` in examples

* pad_to_multiple_of=16 for fp8
This commit is contained in:
hlky
2024-09-17 03:13:06 +01:00
committed by GitHub
parent ba1f1dc132
commit 9f196ef2e0
10 changed files with 79 additions and 16 deletions

View File

@@ -534,11 +534,17 @@ def main():
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
elif accelerator.mixed_precision != "no":
pad_to_multiple_of = 8
else:
pad_to_multiple_of = None
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if accelerator.use_fp16 else None,
pad_to_multiple_of=pad_to_multiple_of,
)
def postprocess_text(preds, labels):