Introduce AcceleratorConfig dataclass (#28664)
* Introduce acceleratorconfig dataclass * Extra second warn * Move import * Try moving import under is_accelerate_available * Quality * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Clean * Remove to_kwargs * Change version * Improve tests by including dispatch and split batches * Improve reliability * Update tests/trainer/test_trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Fixup tests and review nits * Make tests pass * protect import * Protect import * Empty-Commit * Make training_args.to_dict handle the AcceleratorConfig --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -76,6 +76,7 @@ from .trainer_callback import (
|
|||||||
TrainerState,
|
TrainerState,
|
||||||
)
|
)
|
||||||
from .trainer_pt_utils import (
|
from .trainer_pt_utils import (
|
||||||
|
AcceleratorConfig,
|
||||||
DistributedTensorGatherer,
|
DistributedTensorGatherer,
|
||||||
IterableDatasetShard,
|
IterableDatasetShard,
|
||||||
LabelSmoother,
|
LabelSmoother,
|
||||||
@@ -4029,11 +4030,21 @@ class Trainer:
|
|||||||
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
||||||
|
|
||||||
# create accelerator object
|
# create accelerator object
|
||||||
|
accelerator_kwargs = {}
|
||||||
|
if self.args.accelerator_config is not None:
|
||||||
|
accelerator_kwargs = self.args.accelerator_config
|
||||||
|
# dict and AcceleratorConfigs are parseable, json files are not
|
||||||
|
if isinstance(accelerator_kwargs, AcceleratorConfig):
|
||||||
|
accelerator_kwargs = accelerator_kwargs.to_dict()
|
||||||
|
elif isinstance(accelerator_kwargs, dict):
|
||||||
|
# Some values may need to go through non-accelerate aligned defaults
|
||||||
|
# and we need to run the `__post_init__` to set them
|
||||||
|
accelerator_kwargs = AcceleratorConfig(**accelerator_kwargs).to_dict()
|
||||||
|
|
||||||
self.accelerator = Accelerator(
|
self.accelerator = Accelerator(
|
||||||
dispatch_batches=self.args.dispatch_batches,
|
|
||||||
split_batches=self.args.split_batches,
|
|
||||||
deepspeed_plugin=self.args.deepspeed_plugin,
|
deepspeed_plugin=self.args.deepspeed_plugin,
|
||||||
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
||||||
|
**accelerator_kwargs,
|
||||||
)
|
)
|
||||||
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
|
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
|
||||||
self.gather_function = self.accelerator.gather_for_metrics
|
self.gather_function = self.accelerator.gather_for_metrics
|
||||||
|
|||||||
@@ -16,7 +16,9 @@
|
|||||||
Torch utilities for the Trainer class.
|
Torch utilities for the Trainer class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -24,7 +26,7 @@ import sys
|
|||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from logging import StreamHandler
|
from logging import StreamHandler
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
@@ -1140,3 +1142,87 @@ if is_sagemaker_mp_enabled():
|
|||||||
# It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
|
# It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
|
||||||
# which is also the name of the decorator so Python is confused.
|
# which is also the name of the decorator so Python is confused.
|
||||||
return tensor.concat().detach().cpu()
|
return tensor.concat().detach().cpu()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AcceleratorConfig:
|
||||||
|
"""
|
||||||
|
A subset of arguments relating to the underlying [`accelerate.Accelerator`]
|
||||||
|
implementation utilized in the `Trainer` that can be customized.
|
||||||
|
Mostly relating to data.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
split_batches (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
|
||||||
|
`True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
|
||||||
|
round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
|
||||||
|
in your script multiplied by the number of processes.
|
||||||
|
dispatch_batches (`bool`, *optional*):
|
||||||
|
If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
|
||||||
|
and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
|
||||||
|
underlying dataset is an `IterableDataset`, `False` otherwise.
|
||||||
|
even_batches (`bool`, *optional*, defaults to `True`):
|
||||||
|
If set to `True`, in cases where the total batch size across all processes does not exactly divide the
|
||||||
|
dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
|
||||||
|
all workers.
|
||||||
|
use_seedable_sampler (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
|
||||||
|
training results are fully reproducable using a different sampling technique. While seed-to-seed results
|
||||||
|
may differ, on average the differences are neglible when using multiple different seeds to compare. Should
|
||||||
|
also be ran with [`~utils.set_seed`] for the best results.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Data related arguments
|
||||||
|
split_batches: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If"
|
||||||
|
" `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a"
|
||||||
|
" round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set"
|
||||||
|
" in your script multiplied by the number of processes."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dispatch_batches: bool = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
|
||||||
|
" and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
|
||||||
|
" underlying dataset is an `IterableDataslet`, `False` otherwise."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
even_batches: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={
|
||||||
|
"help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the"
|
||||||
|
" dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among"
|
||||||
|
" all workers."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
use_seedable_sampler: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`])."
|
||||||
|
"Ensures training results are fully reproducable using a different sampling technique. "
|
||||||
|
"While seed-to-seed results may differ, on average the differences are neglible when using"
|
||||||
|
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json_file(cls, json_file):
|
||||||
|
# Check if exists
|
||||||
|
open_file = io.open if os.path.exists(json_file) else open
|
||||||
|
with open_file(json_file, "r", encoding="utf-8") as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
# Check for keys and load sensible defaults
|
||||||
|
extra_keys = sorted(key for key in config_dict.keys() if key not in cls.__dataclass_fields__.keys())
|
||||||
|
if len(extra_keys) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `transformers`"
|
||||||
|
" version or fix (and potentially remove these keys) from your config file."
|
||||||
|
)
|
||||||
|
return cls(**config_dict)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return copy.deepcopy(self.__dict__)
|
||||||
|
|||||||
@@ -70,6 +70,8 @@ if is_accelerate_available():
|
|||||||
from accelerate.state import AcceleratorState, PartialState
|
from accelerate.state import AcceleratorState, PartialState
|
||||||
from accelerate.utils import DistributedType
|
from accelerate.utils import DistributedType
|
||||||
|
|
||||||
|
from .trainer_pt_utils import AcceleratorConfig
|
||||||
|
|
||||||
if is_torch_tpu_available(check_device=False):
|
if is_torch_tpu_available(check_device=False):
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
@@ -487,6 +489,32 @@ class TrainingArguments:
|
|||||||
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
|
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
|
||||||
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
|
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
|
||||||
`ds_config.json`) or an already loaded json file as a `dict`"
|
`ds_config.json`) or an already loaded json file as a `dict`"
|
||||||
|
|
||||||
|
accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*):
|
||||||
|
Config to be used with the internal `Accelerator` implementation. The value is either a location of
|
||||||
|
accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as `dict`,
|
||||||
|
or an instance of [`~trainer_pt_utils.AcceleratorConfig`].
|
||||||
|
|
||||||
|
A list of config and its options:
|
||||||
|
- split_batches (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
|
||||||
|
`True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
|
||||||
|
round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
|
||||||
|
in your script multiplied by the number of processes.
|
||||||
|
- dispatch_batches (`bool`, *optional*):
|
||||||
|
If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
|
||||||
|
and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
|
||||||
|
underlying dataset is an `IterableDataset`, `False` otherwise.
|
||||||
|
- even_batches (`bool`, *optional*, defaults to `True`):
|
||||||
|
If set to `True`, in cases where the total batch size across all processes does not exactly divide the
|
||||||
|
dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
|
||||||
|
all workers.
|
||||||
|
- use_seedable_sampler (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
|
||||||
|
training results are fully reproducable using a different sampling technique. While seed-to-seed results
|
||||||
|
may differ, on average the differences are neglible when using multiple different seeds to compare. Should
|
||||||
|
also be ran with [`~utils.set_seed`] for the best results.
|
||||||
|
|
||||||
label_smoothing_factor (`float`, *optional*, defaults to 0.0):
|
label_smoothing_factor (`float`, *optional*, defaults to 0.0):
|
||||||
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
|
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
|
||||||
labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
|
labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
|
||||||
@@ -1085,6 +1113,16 @@ class TrainingArguments:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Do not touch this type annotation or it will stop working in CLI
|
# Do not touch this type annotation or it will stop working in CLI
|
||||||
|
accelerator_config: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"Config to be used with the internal Accelerator object initializtion. The value is either a "
|
||||||
|
"accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Do not touch this type annotation or it will stop working in CLI
|
||||||
deepspeed: Optional[str] = field(
|
deepspeed: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -1282,20 +1320,12 @@ class TrainingArguments:
|
|||||||
|
|
||||||
dispatch_batches: Optional[bool] = field(
|
dispatch_batches: Optional[bool] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "Deprecated. Pass {'dispatch_batches':VALUE} to `accelerator_config`."},
|
||||||
"help": "Whether to dispatch batches across devices in distributed training. If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process "
|
|
||||||
"and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
|
|
||||||
"underlying dataset is an `IterableDataset`, `False` otherwise."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
split_batches: Optional[bool] = field(
|
split_batches: Optional[bool] = field(
|
||||||
default=False,
|
default=None,
|
||||||
metadata={
|
metadata={"help": "Deprecated. Pass {'split_batches':True} to `accelerator_config`."},
|
||||||
"help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices during distributed training. If"
|
|
||||||
"set to `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a"
|
|
||||||
"round multiple of the number of processes you are using (such as GPUs)."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
include_tokens_per_second: Optional[bool] = field(
|
include_tokens_per_second: Optional[bool] = field(
|
||||||
@@ -1702,6 +1732,28 @@ class TrainingArguments:
|
|||||||
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
|
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
|
||||||
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
|
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
|
||||||
|
|
||||||
|
if is_accelerate_available():
|
||||||
|
if not isinstance(self.accelerator_config, (AcceleratorConfig, dict)):
|
||||||
|
if self.accelerator_config is None:
|
||||||
|
self.accelerator_config = AcceleratorConfig()
|
||||||
|
else:
|
||||||
|
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
|
||||||
|
if self.dispatch_batches is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
|
||||||
|
" `--accelerator_config {'dispatch_batches':VALUE} instead",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.accelerator_config["dispatch_batches"] = self.dispatch_batches
|
||||||
|
|
||||||
|
if self.split_batches is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
|
||||||
|
" `--accelerator_config {'split_batches':VALUE} instead",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.accelerator_config["split_batches"] = self.split_batches
|
||||||
|
|
||||||
if self.tpu_metrics_debug:
|
if self.tpu_metrics_debug:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
|
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
|
||||||
@@ -2156,6 +2208,9 @@ class TrainingArguments:
|
|||||||
d[k] = [x.value for x in v]
|
d[k] = [x.value for x in v]
|
||||||
if k.endswith("_token"):
|
if k.endswith("_token"):
|
||||||
d[k] = f"<{k.upper()}>"
|
d[k] = f"<{k.upper()}>"
|
||||||
|
# Handle the accelerator_config if passed
|
||||||
|
if is_accelerate_available() and isinstance(v, AcceleratorConfig):
|
||||||
|
d[k] = v.to_dict()
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def to_json_string(self):
|
def to_json_string(self):
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ if is_torch_available():
|
|||||||
TrainerState,
|
TrainerState,
|
||||||
)
|
)
|
||||||
from transformers.modeling_utils import unwrap_model
|
from transformers.modeling_utils import unwrap_model
|
||||||
|
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||||
|
|
||||||
if is_safetensors_available():
|
if is_safetensors_available():
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
@@ -2412,6 +2413,146 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
execute_subprocess_async(command)
|
execute_subprocess_async(command)
|
||||||
# successful return here == success - any errors would have caused an error or a timeout in the sub-call
|
# successful return here == success - any errors would have caused an error or a timeout in the sub-call
|
||||||
|
|
||||||
|
def test_accelerator_config_empty(self):
|
||||||
|
# Checks that a config can be made with the defaults if not passed
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
|
eval_dataset = SampleIterableDataset()
|
||||||
|
|
||||||
|
# Leaves one option as something *not* basic
|
||||||
|
args = RegressionTrainingArguments(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
)
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertEqual(trainer.accelerator.split_batches, False)
|
||||||
|
self.assertEqual(trainer.accelerator.dispatch_batches, None)
|
||||||
|
self.assertEqual(trainer.accelerator.even_batches, True)
|
||||||
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||||
|
|
||||||
|
def test_accelerator_config_from_dict(self):
|
||||||
|
# Checks that accelerator kwargs can be passed through
|
||||||
|
# and the accelerator is initialized respectively
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
|
eval_dataset = SampleIterableDataset()
|
||||||
|
|
||||||
|
# Leaves all options as something *not* basic
|
||||||
|
args = RegressionTrainingArguments(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
accelerator_config={
|
||||||
|
"split_batches": True,
|
||||||
|
"dispatch_batches": True,
|
||||||
|
"even_batches": False,
|
||||||
|
"use_seedable_sampler": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||||
|
self.assertEqual(trainer.accelerator.dispatch_batches, True)
|
||||||
|
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||||
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||||
|
|
||||||
|
def test_accelerator_config_from_yaml(self):
|
||||||
|
# Checks that accelerator kwargs can be passed through
|
||||||
|
# and the accelerator is initialized respectively
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
path_file = Path(tmp_dir) / "accelerator_config.json"
|
||||||
|
with open(path_file, "w") as f:
|
||||||
|
accelerator_config = {
|
||||||
|
"split_batches": True,
|
||||||
|
"dispatch_batches": True,
|
||||||
|
"even_batches": False,
|
||||||
|
"use_seedable_sampler": False,
|
||||||
|
}
|
||||||
|
json.dump(accelerator_config, f)
|
||||||
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
|
eval_dataset = SampleIterableDataset()
|
||||||
|
|
||||||
|
# Leaves all options as something *not* basic
|
||||||
|
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=path_file)
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||||
|
self.assertEqual(trainer.accelerator.dispatch_batches, True)
|
||||||
|
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||||
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||||
|
|
||||||
|
def test_accelerator_config_from_dataclass(self):
|
||||||
|
# Checks that accelerator kwargs can be passed through
|
||||||
|
# and the accelerator is initialized respectively
|
||||||
|
accelerator_config = AcceleratorConfig(
|
||||||
|
split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
|
||||||
|
)
|
||||||
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
|
eval_dataset = SampleIterableDataset()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||||
|
self.assertEqual(trainer.accelerator.dispatch_batches, True)
|
||||||
|
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||||
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
|
||||||
|
|
||||||
|
def test_accelerator_config_from_partial(self):
|
||||||
|
# Checks that accelerator kwargs can be passed through
|
||||||
|
# and the accelerator is initialized respectively
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
|
eval_dataset = SampleIterableDataset()
|
||||||
|
|
||||||
|
# Leaves one option as something *not* basic
|
||||||
|
args = RegressionTrainingArguments(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
accelerator_config={
|
||||||
|
"split_batches": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||||
|
self.assertEqual(trainer.accelerator.dispatch_batches, None)
|
||||||
|
self.assertEqual(trainer.accelerator.even_batches, True)
|
||||||
|
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
|
||||||
|
|
||||||
|
def test_accelerator_config_from_dict_with_deprecated_args(self):
|
||||||
|
# Checks that accelerator kwargs can be passed through
|
||||||
|
# and the accelerator is initialized respectively
|
||||||
|
# and maintains the deprecated args if passed in
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||||
|
model = RegressionPreTrainedModel(config)
|
||||||
|
eval_dataset = SampleIterableDataset()
|
||||||
|
|
||||||
|
# Leaves all options as something *not* basic
|
||||||
|
with self.assertWarns(FutureWarning) as cm:
|
||||||
|
args = RegressionTrainingArguments(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
accelerator_config={
|
||||||
|
"split_batches": True,
|
||||||
|
},
|
||||||
|
dispatch_batches=False,
|
||||||
|
)
|
||||||
|
self.assertIn("dispatch_batches", str(cm.warnings[0].message))
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertEqual(trainer.accelerator.dispatch_batches, False)
|
||||||
|
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||||
|
with self.assertWarns(FutureWarning) as cm:
|
||||||
|
args = RegressionTrainingArguments(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
accelerator_config={
|
||||||
|
"even_batches": False,
|
||||||
|
},
|
||||||
|
split_batches=True,
|
||||||
|
)
|
||||||
|
self.assertIn("split_batches", str(cm.warnings[0].message))
|
||||||
|
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
|
||||||
|
self.assertEqual(trainer.accelerator.split_batches, True)
|
||||||
|
self.assertEqual(trainer.accelerator.even_batches, False)
|
||||||
|
self.assertEqual(trainer.accelerator.dispatch_batches, None)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user