From 70b07d97cf2c5f61fff55700b65528a1b6845cd2 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Thu, 10 Oct 2024 11:09:04 -0700 Subject: [PATCH] Default `synced_gpus` to `True` when using `FullyShardedDataParallel` (#33483) * Default synced_gpus to True when using FullyShardedDataParallel Fixes #30228 Related: * https://github.com/pytorch/pytorch/issues/100069 * https://github.com/pytorch/pytorch/issues/123962 Similar to DeepSpeed ZeRO Stage 3, when using FSDP with multiple GPUs and differently sized data per rank, the ranks reach different synchronization points at the same time, leading to deadlock To avoid this, we can automatically set synced_gpus to True if we detect that a PreTrainedModel is being managed by FSDP using _is_fsdp_managed_module, which was added in 2.0.0 for torch.compile: https://github.com/pytorch/pytorch/blob/v2.0.0/torch/distributed/fsdp/_dynamo_utils.py * Remove test file * ruff formatting * ruff format * Update copyright year Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Add test for FSDP-wrapped model generation Before #33483, these tests would have hung for 10 minutes before crashing due to a timeout error * Ruff format * Move argparse import * Remove barrier I think this might cause more problems if one of the workers was killed * Move import into function to decrease load time https://github.com/huggingface/transformers/pull/33483#discussion_r1787972735 * Add test for accelerate and Trainer https://github.com/huggingface/transformers/pull/33483#discussion_r1790309675 * Refactor imports * Ruff format * Use nullcontext --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/utils.py | 33 ++-- src/transformers/integrations/__init__.py | 2 + src/transformers/integrations/fsdp.py | 33 ++++ .../data2vec/modeling_data2vec_audio.py | 7 +- .../models/deprecated/mctct/modeling_mctct.py | 7 +- .../models/hubert/modeling_hubert.py | 13 +- .../models/m2m_100/modeling_m2m_100.py | 13 +- .../models/musicgen/modeling_musicgen.py | 6 +- .../modeling_musicgen_melody.py | 6 +- .../models/nllb_moe/modeling_nllb_moe.py | 7 +- .../seamless_m4t/modeling_seamless_m4t.py | 13 +- .../modeling_seamless_m4t_v2.py | 13 +- src/transformers/models/sew/modeling_sew.py | 7 +- .../models/speecht5/modeling_speecht5.py | 11 +- .../models/unispeech/modeling_unispeech.py | 13 +- .../unispeech_sat/modeling_unispeech_sat.py | 13 +- src/transformers/models/vits/modeling_vits.py | 7 +- .../models/wav2vec2/modeling_wav2vec2.py | 13 +- .../wav2vec2_bert/modeling_wav2vec2_bert.py | 7 +- .../modeling_wav2vec2_conformer.py | 7 +- .../models/wavlm/modeling_wavlm.py | 13 +- .../models/whisper/generation_whisper.py | 3 +- src/transformers/trainer.py | 4 +- src/transformers/trainer_seq2seq.py | 19 ++- tests/generation/test_fsdp.py | 148 ++++++++++++++++++ tests/trainer/test_trainer_fsdp.py | 114 ++++++++++++++ 26 files changed, 434 insertions(+), 98 deletions(-) create mode 100644 src/transformers/integrations/fsdp.py create mode 100644 tests/generation/test_fsdp.py create mode 100644 tests/trainer/test_trainer_fsdp.py diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b355bbeaa9..5da4878513 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -35,6 +35,7 @@ from ..cache_utils import ( ) from ..configuration_utils import PretrainedConfig from ..integrations.deepspeed import is_deepspeed_zero3_enabled +from ..integrations.fsdp import is_fsdp_managed_module from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..pytorch_utils import isin_mps_friendly from ..tokenization_utils import ExtensionsTrie @@ -1913,9 +1914,9 @@ class GenerationMixin: for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). synced_gpus (`bool`, *optional*): - Whether to continue running the while loop until max_length. Unless overridden this flag will be set to - `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished - generating before other GPUs. Otherwise it'll be set to `False`. + Whether to continue running the while loop until max_length. Unless overridden, this flag will be set + to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid + deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`. assistant_model (`PreTrainedModel`, *optional*): An assistant model that can be used to accelerate generation. The assistant model must have the exact same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model @@ -1962,10 +1963,7 @@ class GenerationMixin: # 2. Set generation parameters if not already defined if synced_gpus is None: - if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1: - synced_gpus = True - else: - synced_gpus = False + synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1 logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() @@ -2499,7 +2497,8 @@ class GenerationMixin: generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. @@ -2702,7 +2701,8 @@ class GenerationMixin: generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. @@ -3105,7 +3105,8 @@ class GenerationMixin: generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. @@ -3307,7 +3308,8 @@ class GenerationMixin: generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -3585,7 +3587,8 @@ class GenerationMixin: generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). model_kwargs: Additional model specific kwargs that will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -3874,7 +3877,8 @@ class GenerationMixin: generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). model_kwargs: Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is an encoder-decoder model the kwargs should include `encoder_outputs`. @@ -4112,7 +4116,8 @@ class GenerationMixin: generation_config ([`~generation.GenerationConfig`]): The generation configuration to be used as parametrization of the decoding method. synced_gpus (`bool`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 03e5b4802e..093e0af298 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -54,6 +54,7 @@ _import_structure = { ], "eetq": ["replace_with_eetq_linear"], "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"], + "fsdp": ["is_fsdp_managed_module"], "ggml": [ "GGUF_CONFIG_MAPPING", "GGUF_TENSOR_MAPPING", @@ -155,6 +156,7 @@ if TYPE_CHECKING: ) from .eetq import replace_with_eetq_linear from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear + from .fsdp import is_fsdp_managed_module from .ggml import ( GGUF_CONFIG_MAPPING, GGUF_TENSOR_MAPPING, diff --git a/src/transformers/integrations/fsdp.py b/src/transformers/integrations/fsdp.py new file mode 100644 index 0000000000..7bcb11fe74 --- /dev/null +++ b/src/transformers/integrations/fsdp.py @@ -0,0 +1,33 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..utils import is_torch_available + + +if TYPE_CHECKING: + from torch import nn + + +def is_fsdp_managed_module(module: nn.Module) -> bool: + if not is_torch_available(): + return False + + import torch.distributed.fsdp + + return isinstance(module, torch.distributed.fsdp.FullyShardedDataParallel) or getattr( + module, "_is_fsdp_managed_module", False + ) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index b6ad74e8c8..590509eaf9 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -826,7 +827,7 @@ class Data2VecAudioEncoder(nn.Module): hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: @@ -836,8 +837,8 @@ class Data2VecAudioEncoder(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index becba11c16..3cbf2cc0bf 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -24,6 +24,7 @@ from torch import nn from ....activations import ACT2FN from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ....integrations.deepspeed import is_deepspeed_zero3_enabled +from ....integrations.fsdp import is_fsdp_managed_module from ....modeling_attn_mask_utils import _prepare_4d_attention_mask from ....modeling_outputs import BaseModelOutput, CausalLMOutput from ....modeling_utils import ( @@ -579,7 +580,7 @@ class MCTCTEncoder(MCTCTPreTrainedModel): f"but it is for {head_mask.size()[0]}." ) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) @@ -588,8 +589,8 @@ class MCTCTEncoder(MCTCTPreTrainedModel): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 08760a3f4a..ad21d768e3 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -968,7 +969,7 @@ class HubertEncoder(nn.Module): hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: @@ -978,8 +979,8 @@ class HubertEncoder(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, @@ -1055,7 +1056,7 @@ class HubertEncoderStableLayerNorm(nn.Module): hidden_states = hidden_states + position_embeddings hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: @@ -1065,8 +1066,8 @@ class HubertEncoderStableLayerNorm(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 9856eec0c2..1588aa28aa 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1055,7 +1056,7 @@ class M2M100Encoder(M2M100PreTrainedModel): f"The head_mask should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." ) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: @@ -1065,8 +1066,8 @@ class M2M100Encoder(M2M100PreTrainedModel): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( @@ -1312,7 +1313,7 @@ class M2M100Decoder(M2M100PreTrainedModel): f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." ) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: @@ -1322,8 +1323,8 @@ class M2M100Decoder(M2M100PreTrainedModel): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync past_key_value = past_key_values[idx] if past_key_values is not None else None diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index c9f3b88c68..8d7f6ad3c7 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1506,7 +1506,8 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin): generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. @@ -2513,7 +2514,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 15cad4072d..96b8d29db8 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -1428,7 +1428,8 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin): generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. @@ -2364,7 +2365,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). streamer (`BaseStreamer`, *optional*): Streamer object that will be used to stream the generated sequences. Generated tokens are passed through `streamer.put(token_ids)` and the streamer is responsible for any further processing. diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index c33844da0f..cedefc4f46 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( MoEModelOutput, @@ -1370,7 +1371,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." ) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: @@ -1380,13 +1381,13 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: + if not skip_the_layer or synced_gpus: layer_head_mask = head_mask[idx] if head_mask is not None else None cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None - # under deepspeed zero3 all gpus must run in sync + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index eb606208bf..adc01ec40f 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -823,7 +824,7 @@ class SeamlessM4TConformerEncoder(nn.Module): else: relative_position_embeddings = None - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for i, layer in enumerate(self.layers): if output_hidden_states: @@ -835,8 +836,8 @@ class SeamlessM4TConformerEncoder(nn.Module): skip_the_layer = ( True if self.training and (dropout_probability < self.config.speech_encoder_layerdrop) else False ) - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, @@ -2863,7 +2864,8 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin): for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. @@ -3149,7 +3151,8 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel): for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index da44913e74..21265faa22 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -770,7 +771,7 @@ class SeamlessM4Tv2ConformerEncoder(nn.Module): hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for i, layer in enumerate(self.layers): if output_hidden_states: @@ -782,8 +783,8 @@ class SeamlessM4Tv2ConformerEncoder(nn.Module): skip_the_layer = ( True if self.training and (dropout_probability < self.config.speech_encoder_layerdrop) else False ) - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, @@ -3121,7 +3122,8 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin): for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. @@ -3417,7 +3419,8 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel): for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). kwargs (`Dict[str, Any]`, *optional*): Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 5dfe54e24a..8638d93385 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -921,7 +922,7 @@ class SEWEncoder(nn.Module): hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: @@ -931,8 +932,8 @@ class SEWEncoder(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 790e6a74a4..dbe57c01d9 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1318,7 +1319,7 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel): position_bias = self.embed_positions(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -1341,8 +1342,8 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel): dropout_probability = torch.rand([]) skip_the_layer = dropout_probability < self.layerdrop - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, @@ -1603,7 +1604,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1] ) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) if self.gradient_checkpointing and self.training: if use_cache: @@ -1636,7 +1637,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): if self.training: dropout_probability = torch.rand([]) skip_the_layer = dropout_probability < self.layerdrop - if skip_the_layer and not deepspeed_zero3_is_enabled: + if skip_the_layer and not synced_gpus: continue past_key_value = past_key_values[idx] if past_key_values is not None else None diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index d2779fc200..eab2303247 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -1004,7 +1005,7 @@ class UniSpeechEncoder(nn.Module): hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: @@ -1014,8 +1015,8 @@ class UniSpeechEncoder(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, @@ -1091,7 +1092,7 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): hidden_states = hidden_states + position_embeddings hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: @@ -1101,8 +1102,8 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 7bb9843448..31d5071dbe 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -1021,7 +1022,7 @@ class UniSpeechSatEncoder(nn.Module): hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: @@ -1031,8 +1032,8 @@ class UniSpeechSatEncoder(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, @@ -1108,7 +1109,7 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): hidden_states = hidden_states + position_embeddings hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: @@ -1118,8 +1119,8 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 23bc8a72f8..66834167d1 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1139,7 +1140,7 @@ class VitsEncoder(nn.Module): hidden_states = hidden_states * padding_mask - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for encoder_layer in self.layers: if output_hidden_states: @@ -1149,8 +1150,8 @@ class VitsEncoder(nn.Module): dropout_probability = np.random.uniform(0, 1) skip_the_layer = self.training and (dropout_probability < self.layerdrop) - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index d79936ab2b..2648722111 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -1038,7 +1039,7 @@ class Wav2Vec2Encoder(nn.Module): hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: @@ -1048,8 +1049,8 @@ class Wav2Vec2Encoder(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, @@ -1124,7 +1125,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): hidden_states = hidden_states + position_embeddings hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for layer in self.layers: if output_hidden_states: @@ -1134,8 +1135,8 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py index ebbf700bf9..6f1d5576df 100644 --- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -730,7 +731,7 @@ class Wav2Vec2BertEncoder(nn.Module): else: relative_position_embeddings = None - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for i, layer in enumerate(self.layers): if output_hidden_states: @@ -740,8 +741,8 @@ class Wav2Vec2BertEncoder(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index c37dd980d4..933bf8f6dc 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -893,7 +894,7 @@ class Wav2Vec2ConformerEncoder(nn.Module): else: relative_position_embeddings = None - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) for i, layer in enumerate(self.layers): if output_hidden_states: @@ -903,8 +904,8 @@ class Wav2Vec2ConformerEncoder(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index fa5fd390f5..4df192fda5 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -697,7 +698,7 @@ class WavLMEncoder(nn.Module): hidden_states = self.layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) position_bias = None for i, layer in enumerate(self.layers): @@ -708,8 +709,8 @@ class WavLMEncoder(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop) - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( layer.__call__, @@ -781,7 +782,7 @@ class WavLMEncoderStableLayerNorm(nn.Module): hidden_states = hidden_states + position_embeddings hidden_states = self.dropout(hidden_states) - deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) position_bias = None for i, layer in enumerate(self.layers): @@ -792,8 +793,8 @@ class WavLMEncoderStableLayerNorm(nn.Module): dropout_probability = torch.rand([]) skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop) - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index a3de765137..0ecdcb4dbd 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -357,7 +357,8 @@ class WhisperGenerationMixin(GenerationMixin): for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). return_timestamps (`bool`, *optional*): Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`. task (`str`, *optional*): diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 9c4ddd268a..20b9f6dad2 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -4031,7 +4031,7 @@ class Trainer: start_time = time.time() model = ( self.accelerator.prepare(model) - if self.is_deepspeed_enabled + if self.is_deepspeed_enabled or self.is_fsdp_enabled else self.accelerator.prepare_model(model, evaluation_mode=True) ) self.model_preparation_time = round(time.time() - start_time, 4) @@ -4634,7 +4634,7 @@ class Trainer: if len(self.accelerator._models) == 0 and model is self.model: model = ( self.accelerator.prepare(model) - if self.is_deepspeed_enabled + if self.is_deepspeed_enabled or self.is_fsdp_enabled else self.accelerator.prepare_model(model, evaluation_mode=True) ) diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index 03fc484fd8..07d0571e44 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import warnings from copy import deepcopy from pathlib import Path @@ -19,10 +20,12 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un import torch from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel from torch.utils.data import Dataset from .generation.configuration_utils import GenerationConfig from .integrations.deepspeed import is_deepspeed_zero3_enabled +from .integrations.fsdp import is_fsdp_managed_module from .trainer import Trainer from .utils import is_datasets_available, logging from .utils.deprecation import deprecate_kwarg @@ -303,10 +306,8 @@ class Seq2SeqTrainer(Trainer): if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None: gen_kwargs.pop("max_length") - default_synced_gpus = True if is_deepspeed_zero3_enabled() else False - gen_kwargs["synced_gpus"] = ( - gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus - ) + default_synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self.model) + gen_kwargs["synced_gpus"] = gen_kwargs.get("synced_gpus", default_synced_gpus) generation_inputs = inputs.copy() # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate @@ -319,7 +320,15 @@ class Seq2SeqTrainer(Trainer): generation_inputs = { k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask") } - generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs) + + summon_full_params_context = ( + FullyShardedDataParallel.summon_full_params(self.model) + if isinstance(self.model, FullyShardedDataParallel) + else contextlib.nullcontext() + ) + + with summon_full_params_context: + generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs) # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop # TODO: remove this hack when the legacy code that initializes generation_config from a model config is diff --git a/tests/generation/test_fsdp.py b/tests/generation/test_fsdp.py new file mode 100644 index 0000000000..904ccdea63 --- /dev/null +++ b/tests/generation/test_fsdp.py @@ -0,0 +1,148 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from typing import Any, Callable + +from transformers import is_torch_available +from transformers.testing_utils import ( + TestCasePlus, + execute_subprocess_async, + get_torch_dist_unique_port, + require_torch_multi_gpu, +) + + +if is_torch_available(): + import functools + + import torch + import torch.distributed + from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method + from torch.distributed.device_mesh import init_device_mesh + from torch.distributed.fsdp import FullyShardedDataParallel + from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + + from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers.models.gpt2.modeling_gpt2 import GPT2Block + + data = 4 * [ + "Hello world!", + "The quick brown fox jumps over the lazy dog.", + ] + + def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]: + """Manage the creation and destruction of the distributed process group for the wrapped function.""" + + def wrapped(*args: Any, **kwargs: Any) -> Any: + torch.distributed.init_process_group(world_size=torch.cuda.device_count()) + try: + return func(*args, **kwargs) + finally: + torch.distributed.destroy_process_group() + + return wrapped + + @manage_process_group + def fsdp_generate(): + torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank())) + + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device) + + fsdp_model = FullyShardedDataParallel( + model, + auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block}), + limit_all_gathers=True, + use_orig_params=True, + ) + + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device) + + with FullyShardedDataParallel.summon_full_params(fsdp_model): + _ = fsdp_model.module.generate( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + max_length=30, + ) + + @manage_process_group + def fsdp2_generate(): + torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank())) + + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device) + + mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),)) + for submodule in model.modules(): + if isinstance(submodule, GPT2Block): + fully_shard(submodule, mesh=mesh) + fully_shard(model, mesh=mesh) + + register_fsdp_forward_method(model, "generate") + + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device) + + _ = model.generate( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + max_length=30, + ) + + +class TestFSDPGeneration(TestCasePlus): + @require_torch_multi_gpu + def test_fsdp_generate(self): + distributed_args = f"""--nproc_per_node={torch.cuda.device_count()} + --master_port={get_torch_dist_unique_port()} + {self.test_file_dir}/test_fsdp.py + """.split() + args = "--fsdp".split() + cmd = ["torchrun"] + distributed_args + args + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success - any errors would have caused an error in the sub-call + + @require_torch_multi_gpu + def test_fsdp2_generate(self): + distributed_args = f"""--nproc_per_node={torch.cuda.device_count()} + --master_port={get_torch_dist_unique_port()} + {self.test_file_dir}/test_fsdp.py + """.split() + args = "--fsdp2".split() + cmd = ["torchrun"] + distributed_args + args + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success - any errors would have caused an error in the sub-call + + +if __name__ == "__main__": + # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs: + # + # PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 --output_dir output_dir ./tests/generation/test_fsdp.py --fsdp + + class CLIArgs(argparse.Namespace): + fsdp: bool + fsdp2: bool + + parser = argparse.ArgumentParser() + group = parser.add_mutually_exclusive_group() + group.add_argument("--fsdp", action="store_true") + group.add_argument("--fsdp2", action="store_true") + args = parser.parse_args(namespace=CLIArgs()) + + if args.fsdp: + fsdp_generate() + elif args.fsdp2: + fsdp2_generate() + else: + raise ValueError("Missing test selection") diff --git a/tests/trainer/test_trainer_fsdp.py b/tests/trainer/test_trainer_fsdp.py new file mode 100644 index 0000000000..994a82a8db --- /dev/null +++ b/tests/trainer/test_trainer_fsdp.py @@ -0,0 +1,114 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +from transformers import is_torch_available +from transformers.testing_utils import ( + TestCasePlus, + execute_subprocess_async, + get_torch_dist_unique_port, + require_accelerate, + require_torch_multi_gpu, +) + + +if is_torch_available(): + import torch + import torch.distributed + import torch.utils.data + + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForSeq2Seq, + EvalPrediction, + GenerationConfig, + HfArgumentParser, + PreTrainedTokenizerBase, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + ) + + class DummyTextDataset(torch.utils.data.Dataset[str]): + def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None: + data = 4 * [ + "Hello world!", + "The quick brown fox jumps over the lazy dog.", + ] + self.data = [ + {k: v.squeeze(0) for k, v in tokenizer(item, return_tensors="pt", return_attention_mask=True).items()} + for item in data + ] + for item in self.data: + item["labels"] = item["input_ids"] + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, i: int) -> str: + return self.data[i] + + +class TestFSDPTrainer(TestCasePlus): + @require_accelerate + @require_torch_multi_gpu + def test_trainer(self): + output_dir = self.get_auto_remove_tmp_dir() + cmd = [ + "accelerate", + "launch", + "--use_fsdp", + "--main_process_port", + f"{get_torch_dist_unique_port()}", + "--num_processes", + f"{torch.cuda.device_count()}", + "--fsdp_transformer_layer_cls_to_wrap", + "GPT2Block", + f"{self.test_file_dir}/test_trainer_fsdp.py", + "--output_dir", + f"{output_dir}", + "--report_to", + "none", + ] + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success - any errors would have caused an error in the sub-call + + +if __name__ == "__main__": + parser = HfArgumentParser((Seq2SeqTrainingArguments,)) + training_args = parser.parse_args_into_dataclasses()[0] + training_args.per_device_eval_batch_size = 1 + training_args.use_legacy_prediction_loop = False + training_args.predict_with_generate = True + training_args.generation_config = GenerationConfig(max_length=30) + + pretrained_model_name = "hf-internal-testing/tiny-random-gpt2" + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) + tokenizer.pad_token = tokenizer.eos_token + device = torch.device(torch.distributed.get_rank()) + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name).to(device) + + def compute_metrics(p: EvalPrediction) -> Dict[str, bool]: + return {"accuracy": (p.predictions == p.label_ids).mean()} + + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + data_collator=DataCollatorForSeq2Seq(tokenizer, model), + eval_dataset=DummyTextDataset(tokenizer), + compute_metrics=compute_metrics, + ) + + metrics = trainer.evaluate()