fsdp fixes and enhancements (#24980)
* fix fsdp prepare to remove the warnings and fix excess memory usage * Update training_args.py * parity for FSDP+XLA * Update trainer.py
This commit is contained in:
committed by
GitHub
parent
ec3dfe5e24
commit
f4eb459ef2
@@ -441,7 +441,7 @@ as the model saving with FSDP activated is only available with recent fixes.
|
|||||||
- Remaining FSDP config is passed via `--fsdp_config <path_to_fsdp_config.json>`. It is either a location of
|
- Remaining FSDP config is passed via `--fsdp_config <path_to_fsdp_config.json>`. It is either a location of
|
||||||
FSDP json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.
|
FSDP json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.
|
||||||
- If auto wrapping is enabled, you can either use transformer based auto wrap policy or size based auto wrap policy.
|
- If auto wrapping is enabled, you can either use transformer based auto wrap policy or size based auto wrap policy.
|
||||||
- For transformer based auto wrap policy, please specify `fsdp_transformer_layer_cls_to_wrap` in the config file.
|
- For transformer based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
|
||||||
This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`], ...
|
This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`], ...
|
||||||
This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units.
|
This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units.
|
||||||
Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
|
Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
|
||||||
@@ -482,7 +482,7 @@ Pass `--fsdp "full shard"` along with following changes to be made in `--fsdp_co
|
|||||||
This setting can only be used when the xla flag is set to true, and an auto wrapping policy is specified through
|
This setting can only be used when the xla flag is set to true, and an auto wrapping policy is specified through
|
||||||
`fsdp_min_num_params` or `fsdp_transformer_layer_cls_to_wrap`.
|
`fsdp_min_num_params` or `fsdp_transformer_layer_cls_to_wrap`.
|
||||||
- You can either use transformer based auto wrap policy or size based auto wrap policy.
|
- You can either use transformer based auto wrap policy or size based auto wrap policy.
|
||||||
- For transformer based auto wrap policy, please specify `fsdp_transformer_layer_cls_to_wrap` in the config file.
|
- For transformer based auto wrap policy, it is recommended to specify `fsdp_transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
|
||||||
This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`], ...
|
This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`], ...
|
||||||
This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units.
|
This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units.
|
||||||
Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
|
Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
|
||||||
|
|||||||
@@ -1377,18 +1377,24 @@ class Trainer:
|
|||||||
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
|
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
|
||||||
auto_wrap_policy = None
|
auto_wrap_policy = None
|
||||||
auto_wrapper_callable = None
|
auto_wrapper_callable = None
|
||||||
|
default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
|
||||||
|
fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
|
||||||
|
"fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
|
||||||
auto_wrap_policy = functools.partial(
|
auto_wrap_policy = functools.partial(
|
||||||
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
|
||||||
)
|
)
|
||||||
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
|
elif fsdp_transformer_layer_cls_to_wrap is not None:
|
||||||
transformer_cls_to_wrap = set()
|
transformer_cls_to_wrap = set()
|
||||||
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
|
for layer_class in fsdp_transformer_layer_cls_to_wrap:
|
||||||
transformer_cls = get_module_class_from_name(model, layer_class)
|
transformer_cls = get_module_class_from_name(model, layer_class)
|
||||||
if transformer_cls is None:
|
if transformer_cls is None:
|
||||||
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
raise Exception("Could not find the transformer layer class to wrap in the model.")
|
||||||
else:
|
else:
|
||||||
transformer_cls_to_wrap.add(transformer_cls)
|
transformer_cls_to_wrap.add(transformer_cls)
|
||||||
|
|
||||||
auto_wrap_policy = functools.partial(
|
auto_wrap_policy = functools.partial(
|
||||||
transformer_auto_wrap_policy,
|
transformer_auto_wrap_policy,
|
||||||
# Transformer layer class to wrap
|
# Transformer layer class to wrap
|
||||||
@@ -1600,6 +1606,7 @@ class Trainer:
|
|||||||
and self.sharded_ddp != ShardedDDPOption.SIMPLE
|
and self.sharded_ddp != ShardedDDPOption.SIMPLE
|
||||||
or is_sagemaker_mp_enabled()
|
or is_sagemaker_mp_enabled()
|
||||||
or self.fsdp is not None
|
or self.fsdp is not None
|
||||||
|
or self.is_fsdp_enabled
|
||||||
)
|
)
|
||||||
|
|
||||||
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
# We need to reset the scheduler, as its parameters may be different on subsequent calls
|
||||||
@@ -1631,6 +1638,8 @@ class Trainer:
|
|||||||
use_accelerator_prepare = True if model is self.model else False
|
use_accelerator_prepare = True if model is self.model else False
|
||||||
|
|
||||||
if delay_optimizer_creation:
|
if delay_optimizer_creation:
|
||||||
|
if use_accelerator_prepare:
|
||||||
|
self.model = self.accelerator.prepare(self.model)
|
||||||
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
|
||||||
|
|
||||||
# prepare using `accelerator` prepare
|
# prepare using `accelerator` prepare
|
||||||
|
|||||||
@@ -1567,6 +1567,7 @@ class TrainingArguments:
|
|||||||
elif fsdp_option == FSDPOption.OFFLOAD:
|
elif fsdp_option == FSDPOption.OFFLOAD:
|
||||||
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
|
||||||
elif fsdp_option == FSDPOption.AUTO_WRAP:
|
elif fsdp_option == FSDPOption.AUTO_WRAP:
|
||||||
|
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
|
||||||
if self.fsdp_config["fsdp_min_num_params"] > 0:
|
if self.fsdp_config["fsdp_min_num_params"] > 0:
|
||||||
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
|
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
|
||||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
|
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
|
||||||
@@ -1574,7 +1575,6 @@ class TrainingArguments:
|
|||||||
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
|
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
|
||||||
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
|
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
|
||||||
)
|
)
|
||||||
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
|
|
||||||
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
|
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
|
||||||
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
|
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user