Replace accelerator.use_fp16 in examples (#33513)
* Replace `accelerator.use_fp16` in examples * pad_to_multiple_of=16 for fp8
This commit is contained in:
@@ -534,11 +534,17 @@ def main():
|
||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
|
||||
label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
if accelerator.mixed_precision == "fp8":
|
||||
pad_to_multiple_of = 16
|
||||
elif accelerator.mixed_precision != "no":
|
||||
pad_to_multiple_of = 8
|
||||
else:
|
||||
pad_to_multiple_of = None
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
model=model,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if accelerator.use_fp16 else None,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
)
|
||||
|
||||
def postprocess_text(preds, labels):
|
||||
|
||||
Reference in New Issue
Block a user