Default synced_gpus to True when using FullyShardedDataParallel (#33483)

* Default synced_gpus to True when using FullyShardedDataParallel

Fixes #30228

Related:

* https://github.com/pytorch/pytorch/issues/100069
* https://github.com/pytorch/pytorch/issues/123962

Similar to DeepSpeed ZeRO Stage 3, when using FSDP with multiple GPUs and differently sized data per rank, the ranks reach different synchronization points at the same time, leading to deadlock

To avoid this, we can automatically set synced_gpus to True if we detect that a PreTrainedModel is being managed by FSDP using _is_fsdp_managed_module, which was added in 2.0.0 for torch.compile: https://github.com/pytorch/pytorch/blob/v2.0.0/torch/distributed/fsdp/_dynamo_utils.py

* Remove test file

* ruff formatting

* ruff format

* Update copyright year

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add test for FSDP-wrapped model generation

Before #33483, these tests would have hung for 10 minutes before crashing due to a timeout error

* Ruff format

* Move argparse import

* Remove barrier

I think this might cause more problems if one of the workers was killed

* Move import into function to decrease load time

https://github.com/huggingface/transformers/pull/33483#discussion_r1787972735

* Add test for accelerate and Trainer

https://github.com/huggingface/transformers/pull/33483#discussion_r1790309675

* Refactor imports

* Ruff format

* Use nullcontext

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Matthew Hoffman
2024-10-10 11:09:04 -07:00
committed by GitHub
parent 24b82f3cd5
commit 70b07d97cf
26 changed files with 434 additions and 98 deletions

View File

@@ -35,6 +35,7 @@ from ..cache_utils import (
)
from ..configuration_utils import PretrainedConfig
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..integrations.fsdp import is_fsdp_managed_module
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils import ExtensionsTrie
@@ -1913,9 +1914,9 @@ class GenerationMixin:
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*):
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
generating before other GPUs. Otherwise it'll be set to `False`.
Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
assistant_model (`PreTrainedModel`, *optional*):
An assistant model that can be used to accelerate generation. The assistant model must have the exact
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
@@ -1962,10 +1963,7 @@ class GenerationMixin:
# 2. Set generation parameters if not already defined
if synced_gpus is None:
if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
synced_gpus = True
else:
synced_gpus = False
synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
@@ -2499,7 +2497,8 @@ class GenerationMixin:
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
@@ -2702,7 +2701,8 @@ class GenerationMixin:
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
@@ -3105,7 +3105,8 @@ class GenerationMixin:
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
@@ -3307,7 +3308,8 @@ class GenerationMixin:
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -3585,7 +3587,8 @@ class GenerationMixin:
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
model_kwargs:
Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -3874,7 +3877,8 @@ class GenerationMixin:
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -4112,7 +4116,8 @@ class GenerationMixin:
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.

View File

@@ -54,6 +54,7 @@ _import_structure = {
],
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
"fsdp": ["is_fsdp_managed_module"],
"ggml": [
"GGUF_CONFIG_MAPPING",
"GGUF_TENSOR_MAPPING",
@@ -155,6 +156,7 @@ if TYPE_CHECKING:
)
from .eetq import replace_with_eetq_linear
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
from .fsdp import is_fsdp_managed_module
from .ggml import (
GGUF_CONFIG_MAPPING,
GGUF_TENSOR_MAPPING,

View File

@@ -0,0 +1,33 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from ..utils import is_torch_available
if TYPE_CHECKING:
from torch import nn
def is_fsdp_managed_module(module: nn.Module) -> bool:
if not is_torch_available():
return False
import torch.distributed.fsdp
return isinstance(module, torch.distributed.fsdp.FullyShardedDataParallel) or getattr(
module, "_is_fsdp_managed_module", False
)

View File

@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@@ -826,7 +827,7 @@ class Data2VecAudioEncoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
@@ -836,8 +837,8 @@ class Data2VecAudioEncoder(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,

View File

@@ -24,6 +24,7 @@ from torch import nn
from ....activations import ACT2FN
from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ....integrations.deepspeed import is_deepspeed_zero3_enabled
from ....integrations.fsdp import is_fsdp_managed_module
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
from ....modeling_outputs import BaseModelOutput, CausalLMOutput
from ....modeling_utils import (
@@ -579,7 +580,7 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
f"but it is for {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
@@ -588,8 +589,8 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,

View File

@@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -968,7 +969,7 @@ class HubertEncoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
@@ -978,8 +979,8 @@ class HubertEncoder(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
@@ -1055,7 +1056,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
@@ -1065,8 +1066,8 @@ class HubertEncoderStableLayerNorm(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(

View File

@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
@@ -1055,7 +1056,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -1065,8 +1066,8 @@ class M2M100Encoder(M2M100PreTrainedModel):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
@@ -1312,7 +1313,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -1322,8 +1323,8 @@ class M2M100Decoder(M2M100PreTrainedModel):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
past_key_value = past_key_values[idx] if past_key_values is not None else None

View File

@@ -1506,7 +1506,8 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
@@ -2513,7 +2514,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.

View File

@@ -1428,7 +1428,8 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin):
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
@@ -2364,7 +2365,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.

View File

@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
MoEModelOutput,
@@ -1370,7 +1371,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
@@ -1380,13 +1381,13 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
if not skip_the_layer or synced_gpus:
layer_head_mask = head_mask[idx] if head_mask is not None else None
cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
past_key_value = past_key_values[idx] if past_key_values is not None else None
# under deepspeed zero3 all gpus must run in sync
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(

View File

@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -823,7 +824,7 @@ class SeamlessM4TConformerEncoder(nn.Module):
else:
relative_position_embeddings = None
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for i, layer in enumerate(self.layers):
if output_hidden_states:
@@ -835,8 +836,8 @@ class SeamlessM4TConformerEncoder(nn.Module):
skip_the_layer = (
True if self.training and (dropout_probability < self.config.speech_encoder_layerdrop) else False
)
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
@@ -2863,7 +2864,8 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
@@ -3149,7 +3151,8 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.

View File

@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -770,7 +771,7 @@ class SeamlessM4Tv2ConformerEncoder(nn.Module):
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for i, layer in enumerate(self.layers):
if output_hidden_states:
@@ -782,8 +783,8 @@ class SeamlessM4Tv2ConformerEncoder(nn.Module):
skip_the_layer = (
True if self.training and (dropout_probability < self.config.speech_encoder_layerdrop) else False
)
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
@@ -3121,7 +3122,8 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
@@ -3417,7 +3419,8 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel):
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.

View File

@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -921,7 +922,7 @@ class SEWEncoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
@@ -931,8 +932,8 @@ class SEWEncoder(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,

View File

@@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -1318,7 +1319,7 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel):
position_bias = self.embed_positions(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
@@ -1341,8 +1342,8 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel):
dropout_probability = torch.rand([])
skip_the_layer = dropout_probability < self.layerdrop
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
@@ -1603,7 +1604,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
if self.gradient_checkpointing and self.training:
if use_cache:
@@ -1636,7 +1637,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
if self.training:
dropout_probability = torch.rand([])
skip_the_layer = dropout_probability < self.layerdrop
if skip_the_layer and not deepspeed_zero3_is_enabled:
if skip_the_layer and not synced_gpus:
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None

View File

@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -1004,7 +1005,7 @@ class UniSpeechEncoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
@@ -1014,8 +1015,8 @@ class UniSpeechEncoder(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
@@ -1091,7 +1092,7 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
@@ -1101,8 +1102,8 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(

View File

@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@@ -1021,7 +1022,7 @@ class UniSpeechSatEncoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
@@ -1031,8 +1032,8 @@ class UniSpeechSatEncoder(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
@@ -1108,7 +1109,7 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
@@ -1118,8 +1119,8 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(

View File

@@ -25,6 +25,7 @@ from torch import nn
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -1139,7 +1140,7 @@ class VitsEncoder(nn.Module):
hidden_states = hidden_states * padding_mask
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for encoder_layer in self.layers:
if output_hidden_states:
@@ -1149,8 +1150,8 @@ class VitsEncoder(nn.Module):
dropout_probability = np.random.uniform(0, 1)
skip_the_layer = self.training and (dropout_probability < self.layerdrop)
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,

View File

@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@@ -1038,7 +1039,7 @@ class Wav2Vec2Encoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
@@ -1048,8 +1049,8 @@ class Wav2Vec2Encoder(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
@@ -1124,7 +1125,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for layer in self.layers:
if output_hidden_states:
@@ -1134,8 +1135,8 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(

View File

@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -730,7 +731,7 @@ class Wav2Vec2BertEncoder(nn.Module):
else:
relative_position_embeddings = None
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for i, layer in enumerate(self.layers):
if output_hidden_states:
@@ -740,8 +741,8 @@ class Wav2Vec2BertEncoder(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,

View File

@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@@ -893,7 +894,7 @@ class Wav2Vec2ConformerEncoder(nn.Module):
else:
relative_position_embeddings = None
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
for i, layer in enumerate(self.layers):
if output_hidden_states:
@@ -903,8 +904,8 @@ class Wav2Vec2ConformerEncoder(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,

View File

@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@@ -697,7 +698,7 @@ class WavLMEncoder(nn.Module):
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
position_bias = None
for i, layer in enumerate(self.layers):
@@ -708,8 +709,8 @@ class WavLMEncoder(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
@@ -781,7 +782,7 @@ class WavLMEncoderStableLayerNorm(nn.Module):
hidden_states = hidden_states + position_embeddings
hidden_states = self.dropout(hidden_states)
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
position_bias = None
for i, layer in enumerate(self.layers):
@@ -792,8 +793,8 @@ class WavLMEncoderStableLayerNorm(nn.Module):
dropout_probability = torch.rand([])
skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if not skip_the_layer or synced_gpus:
# under fsdp or deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(

View File

@@ -357,7 +357,8 @@ class WhisperGenerationMixin(GenerationMixin):
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
return_timestamps (`bool`, *optional*):
Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
task (`str`, *optional*):

View File

@@ -4031,7 +4031,7 @@ class Trainer:
start_time = time.time()
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled
if self.is_deepspeed_enabled or self.is_fsdp_enabled
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
self.model_preparation_time = round(time.time() - start_time, 4)
@@ -4634,7 +4634,7 @@ class Trainer:
if len(self.accelerator._models) == 0 and model is self.model:
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled
if self.is_deepspeed_enabled or self.is_fsdp_enabled
else self.accelerator.prepare_model(model, evaluation_mode=True)
)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import warnings
from copy import deepcopy
from pathlib import Path
@@ -19,10 +20,12 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.utils.data import Dataset
from .generation.configuration_utils import GenerationConfig
from .integrations.deepspeed import is_deepspeed_zero3_enabled
from .integrations.fsdp import is_fsdp_managed_module
from .trainer import Trainer
from .utils import is_datasets_available, logging
from .utils.deprecation import deprecate_kwarg
@@ -303,10 +306,8 @@ class Seq2SeqTrainer(Trainer):
if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
gen_kwargs.pop("max_length")
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
gen_kwargs["synced_gpus"] = (
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
)
default_synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self.model)
gen_kwargs["synced_gpus"] = gen_kwargs.get("synced_gpus", default_synced_gpus)
generation_inputs = inputs.copy()
# If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
@@ -319,6 +320,14 @@ class Seq2SeqTrainer(Trainer):
generation_inputs = {
k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
}
summon_full_params_context = (
FullyShardedDataParallel.summon_full_params(self.model)
if isinstance(self.model, FullyShardedDataParallel)
else contextlib.nullcontext()
)
with summon_full_params_context:
generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop

View File

@@ -0,0 +1,148 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Any, Callable
from transformers import is_torch_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_gpu,
)
if is_torch_available():
import functools
import torch
import torch.distributed
from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
data = 4 * [
"Hello world!",
"The quick brown fox jumps over the lazy dog.",
]
def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
"""Manage the creation and destruction of the distributed process group for the wrapped function."""
def wrapped(*args: Any, **kwargs: Any) -> Any:
torch.distributed.init_process_group(world_size=torch.cuda.device_count())
try:
return func(*args, **kwargs)
finally:
torch.distributed.destroy_process_group()
return wrapped
@manage_process_group
def fsdp_generate():
torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
fsdp_model = FullyShardedDataParallel(
model,
auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block}),
limit_all_gathers=True,
use_orig_params=True,
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)
with FullyShardedDataParallel.summon_full_params(fsdp_model):
_ = fsdp_model.module.generate(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
max_length=30,
)
@manage_process_group
def fsdp2_generate():
torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
for submodule in model.modules():
if isinstance(submodule, GPT2Block):
fully_shard(submodule, mesh=mesh)
fully_shard(model, mesh=mesh)
register_fsdp_forward_method(model, "generate")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)
_ = model.generate(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
max_length=30,
)
class TestFSDPGeneration(TestCasePlus):
@require_torch_multi_gpu
def test_fsdp_generate(self):
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_fsdp.py
""".split()
args = "--fsdp".split()
cmd = ["torchrun"] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call
@require_torch_multi_gpu
def test_fsdp2_generate(self):
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_fsdp.py
""".split()
args = "--fsdp2".split()
cmd = ["torchrun"] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call
if __name__ == "__main__":
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
#
# PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 --output_dir output_dir ./tests/generation/test_fsdp.py --fsdp
class CLIArgs(argparse.Namespace):
fsdp: bool
fsdp2: bool
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--fsdp", action="store_true")
group.add_argument("--fsdp2", action="store_true")
args = parser.parse_args(namespace=CLIArgs())
if args.fsdp:
fsdp_generate()
elif args.fsdp2:
fsdp2_generate()
else:
raise ValueError("Missing test selection")

View File

@@ -0,0 +1,114 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from transformers import is_torch_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
get_torch_dist_unique_port,
require_accelerate,
require_torch_multi_gpu,
)
if is_torch_available():
import torch
import torch.distributed
import torch.utils.data
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
EvalPrediction,
GenerationConfig,
HfArgumentParser,
PreTrainedTokenizerBase,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
)
class DummyTextDataset(torch.utils.data.Dataset[str]):
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
data = 4 * [
"Hello world!",
"The quick brown fox jumps over the lazy dog.",
]
self.data = [
{k: v.squeeze(0) for k, v in tokenizer(item, return_tensors="pt", return_attention_mask=True).items()}
for item in data
]
for item in self.data:
item["labels"] = item["input_ids"]
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, i: int) -> str:
return self.data[i]
class TestFSDPTrainer(TestCasePlus):
@require_accelerate
@require_torch_multi_gpu
def test_trainer(self):
output_dir = self.get_auto_remove_tmp_dir()
cmd = [
"accelerate",
"launch",
"--use_fsdp",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"--num_processes",
f"{torch.cuda.device_count()}",
"--fsdp_transformer_layer_cls_to_wrap",
"GPT2Block",
f"{self.test_file_dir}/test_trainer_fsdp.py",
"--output_dir",
f"{output_dir}",
"--report_to",
"none",
]
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call
if __name__ == "__main__":
parser = HfArgumentParser((Seq2SeqTrainingArguments,))
training_args = parser.parse_args_into_dataclasses()[0]
training_args.per_device_eval_batch_size = 1
training_args.use_legacy_prediction_loop = False
training_args.predict_with_generate = True
training_args.generation_config = GenerationConfig(max_length=30)
pretrained_model_name = "hf-internal-testing/tiny-random-gpt2"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
tokenizer.pad_token = tokenizer.eos_token
device = torch.device(torch.distributed.get_rank())
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name).to(device)
def compute_metrics(p: EvalPrediction) -> Dict[str, bool]:
return {"accuracy": (p.predictions == p.label_ids).mean()}
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
data_collator=DataCollatorForSeq2Seq(tokenizer, model),
eval_dataset=DummyTextDataset(tokenizer),
compute_metrics=compute_metrics,
)
metrics = trainer.evaluate()