Replace accelerator.use_fp16 in examples (#33513)
* Replace `accelerator.use_fp16` in examples * pad_to_multiple_of=16 for fp8
This commit is contained in:
@@ -542,9 +542,14 @@ def main():
|
||||
# Otherwise, `DataCollatorForTokenClassification` will apply dynamic padding for us (by padding to the maximum length of
|
||||
# the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple
|
||||
# of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta).
|
||||
data_collator = DataCollatorForLukeTokenClassification(
|
||||
tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)
|
||||
)
|
||||
# For fp8, we pad to multiple of 16.
|
||||
if accelerator.mixed_precision == "fp8":
|
||||
pad_to_multiple_of = 16
|
||||
elif accelerator.mixed_precision != "no":
|
||||
pad_to_multiple_of = 8
|
||||
else:
|
||||
pad_to_multiple_of = None
|
||||
data_collator = DataCollatorForLukeTokenClassification(tokenizer, pad_to_multiple_of=pad_to_multiple_of)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
|
||||
|
||||
@@ -704,7 +704,14 @@ def finetune(accelerator, model_name_or_path, train_file, output_dir, **kwargs):
|
||||
# precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple of
|
||||
# 8s, which will enable the use of Tensor Cores on NVIDIA hardware with
|
||||
# compute capability >= 7.5 (Volta).
|
||||
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None))
|
||||
# For fp8, we pad to multiple of 16.
|
||||
if accelerator.mixed_precision == "fp8":
|
||||
pad_to_multiple_of = 16
|
||||
elif accelerator.mixed_precision != "no":
|
||||
pad_to_multiple_of = 8
|
||||
else:
|
||||
pad_to_multiple_of = None
|
||||
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=pad_to_multiple_of)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
|
||||
Reference in New Issue
Block a user