remove SharedDDP as it is deprecated (#25702)
* remove SharedDDP as it was deprecated * apply review suggestion * make style * Oops, forgot to remove the compute_loss context manager in Seq2SeqTrainer. * remove the unnecessary conditional statement * keep the logic of IPEX * clean code * mix precision setup & make fixup --------- Co-authored-by: statelesshz <jihuazhong1@huawei.com>
This commit is contained in:
@@ -19,7 +19,6 @@ from torch import nn
|
|||||||
from torch.utils.data import DistributedSampler, RandomSampler
|
from torch.utils.data import DistributedSampler, RandomSampler
|
||||||
|
|
||||||
from transformers import PreTrainedModel, Trainer, logging
|
from transformers import PreTrainedModel, Trainer, logging
|
||||||
from transformers.integrations import is_fairscale_available
|
|
||||||
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
|
from transformers.models.fsmt.configuration_fsmt import FSMTConfig
|
||||||
from transformers.optimization import (
|
from transformers.optimization import (
|
||||||
Adafactor,
|
Adafactor,
|
||||||
@@ -36,10 +35,6 @@ from transformers.training_args import ParallelMode
|
|||||||
from transformers.utils import is_torch_tpu_available
|
from transformers.utils import is_torch_tpu_available
|
||||||
|
|
||||||
|
|
||||||
if is_fairscale_available():
|
|
||||||
from fairscale.optim import OSS
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
arg_to_scheduler = {
|
arg_to_scheduler = {
|
||||||
@@ -118,14 +113,7 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
"eps": self.args.adam_epsilon,
|
"eps": self.args.adam_epsilon,
|
||||||
}
|
}
|
||||||
optimizer_kwargs["lr"] = self.args.learning_rate
|
optimizer_kwargs["lr"] = self.args.learning_rate
|
||||||
if self.sharded_ddp:
|
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
self.optimizer = OSS(
|
|
||||||
params=optimizer_grouped_parameters,
|
|
||||||
optim=optimizer_cls,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
|
|
||||||
if self.lr_scheduler is None:
|
if self.lr_scheduler is None:
|
||||||
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
|
self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -109,7 +109,6 @@ _deps = [
|
|||||||
"diffusers",
|
"diffusers",
|
||||||
"dill<0.3.5",
|
"dill<0.3.5",
|
||||||
"evaluate>=0.2.0",
|
"evaluate>=0.2.0",
|
||||||
"fairscale>0.3",
|
|
||||||
"faiss-cpu",
|
"faiss-cpu",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
"filelock",
|
"filelock",
|
||||||
@@ -275,7 +274,6 @@ extras["modelcreation"] = deps_list("cookiecutter")
|
|||||||
|
|
||||||
extras["sagemaker"] = deps_list("sagemaker")
|
extras["sagemaker"] = deps_list("sagemaker")
|
||||||
extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
|
extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
|
||||||
extras["fairscale"] = deps_list("fairscale")
|
|
||||||
extras["optuna"] = deps_list("optuna")
|
extras["optuna"] = deps_list("optuna")
|
||||||
extras["ray"] = deps_list("ray[tune]")
|
extras["ray"] = deps_list("ray[tune]")
|
||||||
extras["sigopt"] = deps_list("sigopt")
|
extras["sigopt"] = deps_list("sigopt")
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ deps = {
|
|||||||
"diffusers": "diffusers",
|
"diffusers": "diffusers",
|
||||||
"dill": "dill<0.3.5",
|
"dill": "dill<0.3.5",
|
||||||
"evaluate": "evaluate>=0.2.0",
|
"evaluate": "evaluate>=0.2.0",
|
||||||
"fairscale": "fairscale>0.3",
|
|
||||||
"faiss-cpu": "faiss-cpu",
|
"faiss-cpu": "faiss-cpu",
|
||||||
"fastapi": "fastapi",
|
"fastapi": "fastapi",
|
||||||
"filelock": "filelock",
|
"filelock": "filelock",
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ _import_structure = {
|
|||||||
"is_codecarbon_available",
|
"is_codecarbon_available",
|
||||||
"is_comet_available",
|
"is_comet_available",
|
||||||
"is_dagshub_available",
|
"is_dagshub_available",
|
||||||
"is_fairscale_available",
|
|
||||||
"is_flyte_deck_standard_available",
|
"is_flyte_deck_standard_available",
|
||||||
"is_flytekit_available",
|
"is_flytekit_available",
|
||||||
"is_mlflow_available",
|
"is_mlflow_available",
|
||||||
@@ -118,7 +117,6 @@ if TYPE_CHECKING:
|
|||||||
is_codecarbon_available,
|
is_codecarbon_available,
|
||||||
is_comet_available,
|
is_comet_available,
|
||||||
is_dagshub_available,
|
is_dagshub_available,
|
||||||
is_fairscale_available,
|
|
||||||
is_flyte_deck_standard_available,
|
is_flyte_deck_standard_available,
|
||||||
is_flytekit_available,
|
is_flytekit_available,
|
||||||
is_mlflow_available,
|
is_mlflow_available,
|
||||||
|
|||||||
@@ -134,10 +134,6 @@ def is_dagshub_available():
|
|||||||
return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
|
return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]
|
||||||
|
|
||||||
|
|
||||||
def is_fairscale_available():
|
|
||||||
return importlib.util.find_spec("fairscale") is not None
|
|
||||||
|
|
||||||
|
|
||||||
def is_neptune_available():
|
def is_neptune_available():
|
||||||
return _has_neptune
|
return _has_neptune
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ from transformers import logging as transformers_logging
|
|||||||
|
|
||||||
from .integrations import (
|
from .integrations import (
|
||||||
is_clearml_available,
|
is_clearml_available,
|
||||||
is_fairscale_available,
|
|
||||||
is_optuna_available,
|
is_optuna_available,
|
||||||
is_ray_available,
|
is_ray_available,
|
||||||
is_sigopt_available,
|
is_sigopt_available,
|
||||||
@@ -871,13 +870,6 @@ def require_deepspeed(test_case):
|
|||||||
return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
|
return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_fairscale(test_case):
|
|
||||||
"""
|
|
||||||
Decorator marking a test that requires fairscale
|
|
||||||
"""
|
|
||||||
return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)
|
|
||||||
|
|
||||||
|
|
||||||
def require_apex(test_case):
|
def require_apex(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires apex
|
Decorator marking a test that requires apex
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
|
|||||||
from .integrations import (
|
from .integrations import (
|
||||||
get_reporting_integration_callbacks,
|
get_reporting_integration_callbacks,
|
||||||
hp_params,
|
hp_params,
|
||||||
is_fairscale_available,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# isort: on
|
# isort: on
|
||||||
@@ -58,7 +57,6 @@ from . import __version__
|
|||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||||
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
from .debug_utils import DebugOption, DebugUnderflowOverflow
|
||||||
from .dependency_versions_check import dep_version_check
|
|
||||||
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
|
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
|
||||||
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
|
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
|
||||||
from .modelcard import TrainingSummary
|
from .modelcard import TrainingSummary
|
||||||
@@ -107,7 +105,6 @@ from .trainer_utils import (
|
|||||||
IntervalStrategy,
|
IntervalStrategy,
|
||||||
PredictionOutput,
|
PredictionOutput,
|
||||||
RemoveColumnsCollator,
|
RemoveColumnsCollator,
|
||||||
ShardedDDPOption,
|
|
||||||
TrainerMemoryTracker,
|
TrainerMemoryTracker,
|
||||||
TrainOutput,
|
TrainOutput,
|
||||||
default_compute_objective,
|
default_compute_objective,
|
||||||
@@ -171,15 +168,6 @@ if is_torch_tpu_available(check_device=False):
|
|||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.debug.metrics as met
|
import torch_xla.debug.metrics as met
|
||||||
|
|
||||||
if is_fairscale_available():
|
|
||||||
dep_version_check("fairscale")
|
|
||||||
import fairscale
|
|
||||||
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
|
|
||||||
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
|
|
||||||
from fairscale.nn.wrap import auto_wrap
|
|
||||||
from fairscale.optim import OSS
|
|
||||||
from fairscale.optim.grad_scaler import ShardedGradScaler
|
|
||||||
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
@@ -420,33 +408,6 @@ class Trainer:
|
|||||||
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
|
" model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup Sharded DDP training
|
|
||||||
self.sharded_ddp = None
|
|
||||||
if len(args.sharded_ddp) > 0:
|
|
||||||
if self.is_deepspeed_enabled:
|
|
||||||
raise ValueError(
|
|
||||||
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
|
|
||||||
)
|
|
||||||
if len(args.fsdp) > 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
|
|
||||||
)
|
|
||||||
if args.parallel_mode != ParallelMode.DISTRIBUTED:
|
|
||||||
raise ValueError("Using sharded DDP only works in distributed training.")
|
|
||||||
elif not is_fairscale_available():
|
|
||||||
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
|
|
||||||
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
|
|
||||||
raise ImportError(
|
|
||||||
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
|
|
||||||
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
|
|
||||||
)
|
|
||||||
elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
|
|
||||||
self.sharded_ddp = ShardedDDPOption.SIMPLE
|
|
||||||
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
|
|
||||||
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
|
|
||||||
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
|
|
||||||
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
|
|
||||||
|
|
||||||
self.fsdp = None
|
self.fsdp = None
|
||||||
if len(args.fsdp) > 0:
|
if len(args.fsdp) > 0:
|
||||||
if self.is_deepspeed_enabled:
|
if self.is_deepspeed_enabled:
|
||||||
@@ -488,14 +449,12 @@ class Trainer:
|
|||||||
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
|
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
|
||||||
# and we only use deepspeed for training at the moment
|
# and we only use deepspeed for training at the moment
|
||||||
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
|
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
|
||||||
# 4. Sharded DDP - same as MP
|
# 4. FSDP - same as MP
|
||||||
# 5. FSDP - same as MP
|
|
||||||
self.place_model_on_device = args.place_model_on_device
|
self.place_model_on_device = args.place_model_on_device
|
||||||
if (
|
if (
|
||||||
self.is_model_parallel
|
self.is_model_parallel
|
||||||
or self.is_deepspeed_enabled
|
or self.is_deepspeed_enabled
|
||||||
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
|
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
|
||||||
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
|
|
||||||
or (self.fsdp is not None)
|
or (self.fsdp is not None)
|
||||||
or self.is_fsdp_enabled
|
or self.is_fsdp_enabled
|
||||||
):
|
):
|
||||||
@@ -545,11 +504,11 @@ class Trainer:
|
|||||||
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
|
" `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
|
||||||
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
|
" `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
|
||||||
)
|
)
|
||||||
if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
|
if (self.is_deepspeed_enabled or (self.fsdp is not None)) and (
|
||||||
self.optimizer is not None or self.lr_scheduler is not None
|
self.optimizer is not None or self.lr_scheduler is not None
|
||||||
):
|
):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
|
"Passing `optimizers` is not allowed if Deepspeed or PyTorch FSDP is enabled."
|
||||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||||
)
|
)
|
||||||
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
||||||
@@ -592,7 +551,6 @@ class Trainer:
|
|||||||
|
|
||||||
# Mixed precision setup
|
# Mixed precision setup
|
||||||
self.use_apex = False
|
self.use_apex = False
|
||||||
self.use_cuda_amp = False
|
|
||||||
self.use_cpu_amp = False
|
self.use_cpu_amp = False
|
||||||
|
|
||||||
# Mixed precision setup for SageMaker Model Parallel
|
# Mixed precision setup for SageMaker Model Parallel
|
||||||
@@ -617,33 +575,19 @@ class Trainer:
|
|||||||
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
|
f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
|
||||||
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
|
"but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
|
||||||
)
|
)
|
||||||
|
if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
|
||||||
if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
|
if args.device == torch.device("cpu"):
|
||||||
if args.half_precision_backend == "auto":
|
if args.fp16:
|
||||||
if args.device == torch.device("cpu"):
|
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
|
||||||
if args.fp16:
|
|
||||||
raise ValueError("Tried to use `fp16` but it is not supported on cpu")
|
|
||||||
else:
|
|
||||||
args.half_precision_backend = "cpu_amp"
|
|
||||||
else:
|
else:
|
||||||
args.half_precision_backend = "cuda_amp"
|
args.half_precision_backend = "cpu_amp"
|
||||||
|
|
||||||
logger.info(f"Using {args.half_precision_backend} half precision backend")
|
logger.info(f"Using {args.half_precision_backend} half precision backend")
|
||||||
|
|
||||||
self.do_grad_scaling = False
|
|
||||||
if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
|
if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
|
||||||
# deepspeed and SageMaker Model Parallel manage their own half precision
|
# deepspeed and SageMaker Model Parallel manage their own half precision
|
||||||
if self.sharded_ddp is not None:
|
if args.half_precision_backend == "cpu_amp":
|
||||||
if args.half_precision_backend == "cuda_amp":
|
self.use_cpu_amp = True
|
||||||
self.use_cuda_amp = True
|
self.amp_dtype = torch.bfloat16
|
||||||
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
|
||||||
# bf16 does not need grad scaling
|
|
||||||
self.do_grad_scaling = self.amp_dtype == torch.float16
|
|
||||||
if self.do_grad_scaling:
|
|
||||||
self.scaler = ShardedGradScaler()
|
|
||||||
elif args.half_precision_backend == "cpu_amp":
|
|
||||||
self.use_cpu_amp = True
|
|
||||||
self.amp_dtype = torch.bfloat16
|
|
||||||
elif args.half_precision_backend == "apex":
|
elif args.half_precision_backend == "apex":
|
||||||
if not is_apex_available():
|
if not is_apex_available():
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -652,18 +596,6 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
self.use_apex = True
|
self.use_apex = True
|
||||||
|
|
||||||
# FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
|
|
||||||
if (
|
|
||||||
is_sagemaker_mp_enabled()
|
|
||||||
and self.use_cuda_amp
|
|
||||||
and args.max_grad_norm is not None
|
|
||||||
and args.max_grad_norm > 0
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
|
|
||||||
"along 'max_grad_norm': 0 in your hyperparameters."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Label smoothing
|
# Label smoothing
|
||||||
if self.args.label_smoothing_factor != 0:
|
if self.args.label_smoothing_factor != 0:
|
||||||
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
|
self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
|
||||||
@@ -994,27 +926,20 @@ class Trainer:
|
|||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
||||||
|
|
||||||
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
self.optimizer = OSS(
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
params=optimizer_grouped_parameters,
|
import bitsandbytes
|
||||||
optim=optimizer_cls,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
if optimizer_cls.__name__ == "Adam8bit":
|
|
||||||
import bitsandbytes
|
|
||||||
|
|
||||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
|
|
||||||
skipped = 0
|
skipped = 0
|
||||||
for module in opt_model.modules():
|
for module in opt_model.modules():
|
||||||
if isinstance(module, nn.Embedding):
|
if isinstance(module, nn.Embedding):
|
||||||
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
|
||||||
logger.info(f"skipped {module}: {skipped/2**20}M params")
|
logger.info(f"skipped {module}: {skipped/2**20}M params")
|
||||||
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
manager.register_module_override(module, "weight", {"optim_bits": 32})
|
||||||
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
logger.info(f"skipped: {skipped/2**20}M params")
|
logger.info(f"skipped: {skipped/2**20}M params")
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer(self.optimizer)
|
self.optimizer = smp.DistributedOptimizer(self.optimizer)
|
||||||
@@ -1333,7 +1258,6 @@ class Trainer:
|
|||||||
jit_model(**example_batch)
|
jit_model(**example_batch)
|
||||||
model = jit_model
|
model = jit_model
|
||||||
self.use_cpu_amp = False
|
self.use_cpu_amp = False
|
||||||
self.use_cuda_amp = False
|
|
||||||
except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
|
except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
|
||||||
logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
|
logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
|
||||||
|
|
||||||
@@ -1396,25 +1320,8 @@ class Trainer:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
# Distributed training (should be after apex fp16 initialization)
|
# Distributed training (should be after apex fp16 initialization)
|
||||||
if self.sharded_ddp is not None:
|
|
||||||
# Sharded DDP!
|
|
||||||
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
|
||||||
model = ShardedDDP(model, self.optimizer)
|
|
||||||
else:
|
|
||||||
mixed_precision = self.args.fp16 or self.args.bf16
|
|
||||||
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
|
|
||||||
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
|
|
||||||
# XXX: Breaking the self.model convention but I see no way around it for now.
|
|
||||||
if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
|
|
||||||
model = auto_wrap(model)
|
|
||||||
self.model = model = FullyShardedDDP(
|
|
||||||
model,
|
|
||||||
mixed_precision=mixed_precision,
|
|
||||||
reshard_after_forward=zero_3,
|
|
||||||
cpu_offload=cpu_offload,
|
|
||||||
).to(self.args.device)
|
|
||||||
# Distributed training using PyTorch FSDP
|
# Distributed training using PyTorch FSDP
|
||||||
elif self.fsdp is not None and self.args.fsdp_config["xla"]:
|
if self.fsdp is not None and self.args.fsdp_config["xla"]:
|
||||||
try:
|
try:
|
||||||
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
|
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
|
||||||
from torch_xla.distributed.fsdp import checkpoint_module
|
from torch_xla.distributed.fsdp import checkpoint_module
|
||||||
@@ -1669,13 +1576,7 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
||||||
|
|
||||||
delay_optimizer_creation = (
|
delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled
|
||||||
self.sharded_ddp is not None
|
|
||||||
and self.sharded_ddp != ShardedDDPOption.SIMPLE
|
|
||||||
or is_sagemaker_mp_enabled()
|
|
||||||
or self.fsdp is not None
|
|
||||||
or self.is_fsdp_enabled
|
|
||||||
)
|
|
||||||
|
|
||||||
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
||||||
if self._created_lr_scheduler:
|
if self._created_lr_scheduler:
|
||||||
@@ -1716,7 +1617,7 @@ class Trainer:
|
|||||||
|
|
||||||
# as the model is wrapped, don't use `accelerator.prepare`
|
# as the model is wrapped, don't use `accelerator.prepare`
|
||||||
# this is for unhandled cases such as
|
# this is for unhandled cases such as
|
||||||
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
|
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
|
||||||
use_accelerator_prepare = True if model is self.model else False
|
use_accelerator_prepare = True if model is self.model else False
|
||||||
|
|
||||||
if delay_optimizer_creation:
|
if delay_optimizer_creation:
|
||||||
@@ -1932,14 +1833,6 @@ class Trainer:
|
|||||||
if args.max_grad_norm is not None and args.max_grad_norm > 0:
|
if args.max_grad_norm is not None and args.max_grad_norm > 0:
|
||||||
# deepspeed does its own clipping
|
# deepspeed does its own clipping
|
||||||
|
|
||||||
if self.do_grad_scaling:
|
|
||||||
# Reduce gradients first for XLA
|
|
||||||
if is_torch_tpu_available():
|
|
||||||
gradients = xm._fetch_gradients(self.optimizer)
|
|
||||||
xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
|
|
||||||
# AMP: gradients need unscaling
|
|
||||||
self.scaler.unscale_(self.optimizer)
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled() and args.fp16:
|
if is_sagemaker_mp_enabled() and args.fp16:
|
||||||
self.optimizer.clip_master_grads(args.max_grad_norm)
|
self.optimizer.clip_master_grads(args.max_grad_norm)
|
||||||
elif hasattr(self.optimizer, "clip_grad_norm"):
|
elif hasattr(self.optimizer, "clip_grad_norm"):
|
||||||
@@ -1961,24 +1854,8 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Optimizer step
|
# Optimizer step
|
||||||
optimizer_was_run = True
|
self.optimizer.step()
|
||||||
if is_torch_tpu_available():
|
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
|
||||||
if self.do_grad_scaling:
|
|
||||||
self.scaler.step(self.optimizer)
|
|
||||||
self.scaler.update()
|
|
||||||
else:
|
|
||||||
# tpu-comment: accelerate wrapped optimizers call xm.optimizer_step
|
|
||||||
self.optimizer.step()
|
|
||||||
elif self.do_grad_scaling:
|
|
||||||
scale_before = self.scaler.get_scale()
|
|
||||||
self.scaler.step(self.optimizer)
|
|
||||||
self.scaler.update()
|
|
||||||
scale_after = self.scaler.get_scale()
|
|
||||||
optimizer_was_run = scale_before <= scale_after
|
|
||||||
else:
|
|
||||||
self.optimizer.step()
|
|
||||||
optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
|
|
||||||
|
|
||||||
if optimizer_was_run:
|
if optimizer_was_run:
|
||||||
# Delay optimizer scheduling until metrics are generated
|
# Delay optimizer scheduling until metrics are generated
|
||||||
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
||||||
@@ -2408,9 +2285,6 @@ class Trainer:
|
|||||||
self.model_wrapped.save_checkpoint(output_dir)
|
self.model_wrapped.save_checkpoint(output_dir)
|
||||||
|
|
||||||
# Save optimizer and scheduler
|
# Save optimizer and scheduler
|
||||||
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
|
||||||
self.optimizer.consolidate_state_dict()
|
|
||||||
|
|
||||||
if self.fsdp or self.is_fsdp_enabled:
|
if self.fsdp or self.is_fsdp_enabled:
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
save_fsdp_optimizer(
|
save_fsdp_optimizer(
|
||||||
@@ -2455,8 +2329,6 @@ class Trainer:
|
|||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
if self.do_grad_scaling:
|
|
||||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
|
||||||
|
|
||||||
# Determine the new best metric / best model checkpoint
|
# Determine the new best metric / best model checkpoint
|
||||||
if metrics is not None and self.args.metric_for_best_model is not None:
|
if metrics is not None and self.args.metric_for_best_model is not None:
|
||||||
@@ -2600,8 +2472,6 @@ class Trainer:
|
|||||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||||
reissue_pt_warnings(caught_warnings)
|
reissue_pt_warnings(caught_warnings)
|
||||||
if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
|
|
||||||
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
|
|
||||||
|
|
||||||
def hyperparameter_search(
|
def hyperparameter_search(
|
||||||
self,
|
self,
|
||||||
@@ -2744,12 +2614,8 @@ class Trainer:
|
|||||||
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
|
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
|
||||||
arguments, depending on the situation.
|
arguments, depending on the situation.
|
||||||
"""
|
"""
|
||||||
if self.use_cuda_amp or self.use_cpu_amp:
|
if self.use_cpu_amp:
|
||||||
ctx_manager = (
|
ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
|
||||||
torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
|
|
||||||
if self.use_cpu_amp
|
|
||||||
else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
ctx_manager = contextlib.nullcontext()
|
ctx_manager = contextlib.nullcontext()
|
||||||
|
|
||||||
@@ -2786,9 +2652,7 @@ class Trainer:
|
|||||||
if self.args.n_gpu > 1:
|
if self.args.n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||||
|
|
||||||
if self.do_grad_scaling:
|
if self.use_apex:
|
||||||
self.scaler.scale(loss).backward()
|
|
||||||
elif self.use_apex:
|
|
||||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
else:
|
else:
|
||||||
@@ -2872,12 +2736,7 @@ class Trainer:
|
|||||||
if IS_SAGEMAKER_MP_POST_1_10:
|
if IS_SAGEMAKER_MP_POST_1_10:
|
||||||
# 'user_content.pt' indicates model state_dict saved with smp >= 1.10
|
# 'user_content.pt' indicates model state_dict saved with smp >= 1.10
|
||||||
Path(os.path.join(output_dir, "user_content.pt")).touch()
|
Path(os.path.join(output_dir, "user_content.pt")).touch()
|
||||||
elif (
|
elif self.fsdp is not None or self.is_fsdp_enabled:
|
||||||
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
|
|
||||||
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
|
|
||||||
or self.fsdp is not None
|
|
||||||
or self.is_fsdp_enabled
|
|
||||||
):
|
|
||||||
state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {}
|
state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {}
|
||||||
if self.args.should_save:
|
if self.args.should_save:
|
||||||
self._save(output_dir, state_dict=state_dict)
|
self._save(output_dir, state_dict=state_dict)
|
||||||
|
|||||||
@@ -266,7 +266,6 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
has_labels = "labels" in inputs
|
has_labels = "labels" in inputs
|
||||||
inputs = self._prepare_inputs(inputs)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
# XXX: adapt synced_gpus for fairscale as well
|
|
||||||
# Priority (handled in generate):
|
# Priority (handled in generate):
|
||||||
# non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
|
# non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
|
||||||
if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
|
if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
|
||||||
|
|||||||
@@ -651,14 +651,6 @@ def number_of_arguments(func):
|
|||||||
return len(inspect.signature(func).parameters)
|
return len(inspect.signature(func).parameters)
|
||||||
|
|
||||||
|
|
||||||
class ShardedDDPOption(ExplicitEnum):
|
|
||||||
SIMPLE = "simple"
|
|
||||||
ZERO_DP_2 = "zero_dp_2"
|
|
||||||
ZERO_DP_3 = "zero_dp_3"
|
|
||||||
OFFLOAD = "offload"
|
|
||||||
AUTO_WRAP = "auto_wrap"
|
|
||||||
|
|
||||||
|
|
||||||
def find_executable_batch_size(
|
def find_executable_batch_size(
|
||||||
function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
|
function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ from .trainer_utils import (
|
|||||||
HubStrategy,
|
HubStrategy,
|
||||||
IntervalStrategy,
|
IntervalStrategy,
|
||||||
SchedulerType,
|
SchedulerType,
|
||||||
ShardedDDPOption,
|
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ExplicitEnum,
|
ExplicitEnum,
|
||||||
@@ -328,9 +327,9 @@ class TrainingArguments:
|
|||||||
fp16_backend (`str`, *optional*, defaults to `"auto"`):
|
fp16_backend (`str`, *optional*, defaults to `"auto"`):
|
||||||
This argument is deprecated. Use `half_precision_backend` instead.
|
This argument is deprecated. Use `half_precision_backend` instead.
|
||||||
half_precision_backend (`str`, *optional*, defaults to `"auto"`):
|
half_precision_backend (`str`, *optional*, defaults to `"auto"`):
|
||||||
The backend to use for mixed precision training. Must be one of `"auto", "cuda_amp", "apex", "cpu_amp"`.
|
The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
|
||||||
`"auto"` will use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices
|
use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
|
||||||
will force the requested backend.
|
requested backend.
|
||||||
bf16_full_eval (`bool`, *optional*, defaults to `False`):
|
bf16_full_eval (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
|
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
|
||||||
metric values. This is an experimental API and it may change.
|
metric values. This is an experimental API and it may change.
|
||||||
@@ -410,21 +409,6 @@ class TrainingArguments:
|
|||||||
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
|
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
|
||||||
stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
|
stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
|
||||||
can take a long time) but will not yield the same results as the interrupted training would have.
|
can take a long time) but will not yield the same results as the interrupted training would have.
|
||||||
sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `''`):
|
|
||||||
Use Sharded DDP training from [FairScale](https://github.com/facebookresearch/fairscale) (in distributed
|
|
||||||
training only). This is an experimental feature.
|
|
||||||
|
|
||||||
A list of options along the following:
|
|
||||||
|
|
||||||
- `"simple"`: to use first instance of sharded DDP released by fairscale (`ShardedDDP`) similar to ZeRO-2.
|
|
||||||
- `"zero_dp_2"`: to use the second instance of sharded DPP released by fairscale (`FullyShardedDDP`) in
|
|
||||||
Zero-2 mode (with `reshard_after_forward=False`).
|
|
||||||
- `"zero_dp_3"`: to use the second instance of sharded DPP released by fairscale (`FullyShardedDDP`) in
|
|
||||||
Zero-3 mode (with `reshard_after_forward=True`).
|
|
||||||
- `"offload"`: to add ZeRO-offload (only compatible with `"zero_dp_2"` and `"zero_dp_3"`).
|
|
||||||
|
|
||||||
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
|
|
||||||
list for `False` and `["simple"]` for `True`.
|
|
||||||
fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):
|
fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):
|
||||||
Use PyTorch Distributed Parallel Training (in distributed training only).
|
Use PyTorch Distributed Parallel Training (in distributed training only).
|
||||||
|
|
||||||
@@ -877,7 +861,7 @@ class TrainingArguments:
|
|||||||
default="auto",
|
default="auto",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The backend to be used for half precision.",
|
"help": "The backend to be used for half precision.",
|
||||||
"choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
|
"choices": ["auto", "apex", "cpu_amp"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
bf16_full_eval: bool = field(
|
bf16_full_eval: bool = field(
|
||||||
@@ -996,17 +980,6 @@ class TrainingArguments:
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
sharded_ddp: Optional[Union[List[ShardedDDPOption], str]] = field(
|
|
||||||
default="",
|
|
||||||
metadata={
|
|
||||||
"help": (
|
|
||||||
"Whether or not to use sharded DDP training (in distributed training only). The base option should be"
|
|
||||||
" `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` like"
|
|
||||||
" this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3`"
|
|
||||||
" with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
fsdp: Optional[Union[List[FSDPOption], str]] = field(
|
fsdp: Optional[Union[List[FSDPOption], str]] = field(
|
||||||
default="",
|
default="",
|
||||||
metadata={
|
metadata={
|
||||||
@@ -1154,7 +1127,7 @@ class TrainingArguments:
|
|||||||
default="auto",
|
default="auto",
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Deprecated. Use half_precision_backend instead",
|
"help": "Deprecated. Use half_precision_backend instead",
|
||||||
"choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
|
"choices": ["auto", "apex", "cpu_amp"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
push_to_hub_model_id: Optional[str] = field(
|
push_to_hub_model_id: Optional[str] = field(
|
||||||
@@ -1407,8 +1380,6 @@ class TrainingArguments:
|
|||||||
" `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
|
" `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
|
||||||
" `--half_precision_backend cuda_amp` instead"
|
" `--half_precision_backend cuda_amp` instead"
|
||||||
)
|
)
|
||||||
if not (self.sharded_ddp == "" or not self.sharded_ddp):
|
|
||||||
raise ValueError("sharded_ddp is not supported with bf16")
|
|
||||||
|
|
||||||
if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
|
if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
|
||||||
if self.evaluation_strategy == IntervalStrategy.NO:
|
if self.evaluation_strategy == IntervalStrategy.NO:
|
||||||
@@ -1508,7 +1479,7 @@ class TrainingArguments:
|
|||||||
# no need to assert on else
|
# no need to assert on else
|
||||||
|
|
||||||
# if training args is specified, it will override the one specified in the accelerate config
|
# if training args is specified, it will override the one specified in the accelerate config
|
||||||
if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
|
if self.half_precision_backend != "apex":
|
||||||
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
|
mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
|
||||||
if self.fp16:
|
if self.fp16:
|
||||||
mixed_precision_dtype = "fp16"
|
mixed_precision_dtype = "fp16"
|
||||||
@@ -1541,26 +1512,6 @@ class TrainingArguments:
|
|||||||
" during training"
|
" during training"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not (self.sharded_ddp == "" or not self.sharded_ddp):
|
|
||||||
warnings.warn(
|
|
||||||
"using `sharded_ddp` is deprecated and will be removed in version 4.33"
|
|
||||||
" of 🤗 Transformers. Use `fsdp` instead",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
if isinstance(self.sharded_ddp, bool):
|
|
||||||
self.sharded_ddp = "simple" if self.sharded_ddp else ""
|
|
||||||
if isinstance(self.sharded_ddp, str):
|
|
||||||
self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
|
|
||||||
if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
|
|
||||||
raise ValueError(
|
|
||||||
"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
|
|
||||||
'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
|
|
||||||
)
|
|
||||||
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
|
|
||||||
raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
|
|
||||||
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
|
|
||||||
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
|
|
||||||
|
|
||||||
if isinstance(self.fsdp, bool):
|
if isinstance(self.fsdp, bool):
|
||||||
self.fsdp = "full_shard" if self.fsdp else ""
|
self.fsdp = "full_shard" if self.fsdp else ""
|
||||||
if isinstance(self.fsdp, str):
|
if isinstance(self.fsdp, str):
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import math
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
@@ -32,7 +31,6 @@ from transformers.testing_utils import (
|
|||||||
get_torch_dist_unique_port,
|
get_torch_dist_unique_port,
|
||||||
require_apex,
|
require_apex,
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
require_fairscale,
|
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
@@ -105,36 +103,6 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
def test_run_seq2seq_ddp(self):
|
def test_run_seq2seq_ddp(self):
|
||||||
self.run_seq2seq_quick(distributed=True)
|
self.run_seq2seq_quick(distributed=True)
|
||||||
|
|
||||||
# test --sharded_ddp w/o --fp16
|
|
||||||
@unittest.skip("Requires an update of the env running those tests")
|
|
||||||
@require_torch_multi_gpu
|
|
||||||
@require_fairscale
|
|
||||||
def test_run_seq2seq_sharded_ddp(self):
|
|
||||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
|
|
||||||
|
|
||||||
# test --sharded_ddp w/ --fp16
|
|
||||||
@unittest.skip("Requires an update of the env running those tests")
|
|
||||||
@require_torch_multi_gpu
|
|
||||||
@require_fairscale
|
|
||||||
def test_run_seq2seq_sharded_ddp_fp16(self):
|
|
||||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
|
|
||||||
|
|
||||||
# test --sharded_ddp zero_dp_2 w/o --fp16
|
|
||||||
@unittest.skip("Requires an update of the env running those tests")
|
|
||||||
@require_torch_multi_gpu
|
|
||||||
@require_fairscale
|
|
||||||
def test_run_seq2seq_fully_sharded_ddp(self):
|
|
||||||
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
|
|
||||||
|
|
||||||
# test --sharded_ddp zero_dp_2 w/ --fp16
|
|
||||||
@unittest.skip("Requires an update of the env running those tests")
|
|
||||||
@require_torch_multi_gpu
|
|
||||||
@require_fairscale
|
|
||||||
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
|
|
||||||
self.run_seq2seq_quick(
|
|
||||||
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
|
|
||||||
)
|
|
||||||
|
|
||||||
@require_apex
|
@require_apex
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
def test_run_seq2seq_apex(self):
|
def test_run_seq2seq_apex(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user