Default synced_gpus to True when using FullyShardedDataParallel (#33483)
* Default synced_gpus to True when using FullyShardedDataParallel Fixes #30228 Related: * https://github.com/pytorch/pytorch/issues/100069 * https://github.com/pytorch/pytorch/issues/123962 Similar to DeepSpeed ZeRO Stage 3, when using FSDP with multiple GPUs and differently sized data per rank, the ranks reach different synchronization points at the same time, leading to deadlock To avoid this, we can automatically set synced_gpus to True if we detect that a PreTrainedModel is being managed by FSDP using _is_fsdp_managed_module, which was added in 2.0.0 for torch.compile: https://github.com/pytorch/pytorch/blob/v2.0.0/torch/distributed/fsdp/_dynamo_utils.py * Remove test file * ruff formatting * ruff format * Update copyright year Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Add test for FSDP-wrapped model generation Before #33483, these tests would have hung for 10 minutes before crashing due to a timeout error * Ruff format * Move argparse import * Remove barrier I think this might cause more problems if one of the workers was killed * Move import into function to decrease load time https://github.com/huggingface/transformers/pull/33483#discussion_r1787972735 * Add test for accelerate and Trainer https://github.com/huggingface/transformers/pull/33483#discussion_r1790309675 * Refactor imports * Ruff format * Use nullcontext --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -35,6 +35,7 @@ from ..cache_utils import (
|
||||
)
|
||||
from ..configuration_utils import PretrainedConfig
|
||||
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ..integrations.fsdp import is_fsdp_managed_module
|
||||
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
|
||||
from ..pytorch_utils import isin_mps_friendly
|
||||
from ..tokenization_utils import ExtensionsTrie
|
||||
@@ -1913,9 +1914,9 @@ class GenerationMixin:
|
||||
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
|
||||
Retrieval](https://arxiv.org/abs/2010.00904).
|
||||
synced_gpus (`bool`, *optional*):
|
||||
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
|
||||
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
|
||||
generating before other GPUs. Otherwise it'll be set to `False`.
|
||||
Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
|
||||
to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
|
||||
deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
|
||||
assistant_model (`PreTrainedModel`, *optional*):
|
||||
An assistant model that can be used to accelerate generation. The assistant model must have the exact
|
||||
same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
|
||||
@@ -1962,10 +1963,7 @@ class GenerationMixin:
|
||||
|
||||
# 2. Set generation parameters if not already defined
|
||||
if synced_gpus is None:
|
||||
if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
|
||||
synced_gpus = True
|
||||
else:
|
||||
synced_gpus = False
|
||||
synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
|
||||
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
@@ -2499,7 +2497,8 @@ class GenerationMixin:
|
||||
generation_config ([`~generation.GenerationConfig`]):
|
||||
The generation configuration to be used as parametrization of the decoding method.
|
||||
synced_gpus (`bool`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
@@ -2702,7 +2701,8 @@ class GenerationMixin:
|
||||
generation_config ([`~generation.GenerationConfig`]):
|
||||
The generation configuration to be used as parametrization of the decoding method.
|
||||
synced_gpus (`bool`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
@@ -3105,7 +3105,8 @@ class GenerationMixin:
|
||||
generation_config ([`~generation.GenerationConfig`]):
|
||||
The generation configuration to be used as parametrization of the decoding method.
|
||||
synced_gpus (`bool`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
@@ -3307,7 +3308,8 @@ class GenerationMixin:
|
||||
generation_config ([`~generation.GenerationConfig`]):
|
||||
The generation configuration to be used as parametrization of the decoding method.
|
||||
synced_gpus (`bool`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
||||
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||
@@ -3585,7 +3587,8 @@ class GenerationMixin:
|
||||
generation_config ([`~generation.GenerationConfig`]):
|
||||
The generation configuration to be used as parametrization of the decoding method.
|
||||
synced_gpus (`bool`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
model_kwargs:
|
||||
Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
|
||||
model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||
@@ -3874,7 +3877,8 @@ class GenerationMixin:
|
||||
generation_config ([`~generation.GenerationConfig`]):
|
||||
The generation configuration to be used as parametrization of the decoding method.
|
||||
synced_gpus (`bool`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
||||
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||
@@ -4112,7 +4116,8 @@ class GenerationMixin:
|
||||
generation_config ([`~generation.GenerationConfig`]):
|
||||
The generation configuration to be used as parametrization of the decoding method.
|
||||
synced_gpus (`bool`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
|
||||
@@ -54,6 +54,7 @@ _import_structure = {
|
||||
],
|
||||
"eetq": ["replace_with_eetq_linear"],
|
||||
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
|
||||
"fsdp": ["is_fsdp_managed_module"],
|
||||
"ggml": [
|
||||
"GGUF_CONFIG_MAPPING",
|
||||
"GGUF_TENSOR_MAPPING",
|
||||
@@ -155,6 +156,7 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .eetq import replace_with_eetq_linear
|
||||
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
|
||||
from .fsdp import is_fsdp_managed_module
|
||||
from .ggml import (
|
||||
GGUF_CONFIG_MAPPING,
|
||||
GGUF_TENSOR_MAPPING,
|
||||
|
||||
33
src/transformers/integrations/fsdp.py
Normal file
33
src/transformers/integrations/fsdp.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import nn
|
||||
|
||||
|
||||
def is_fsdp_managed_module(module: nn.Module) -> bool:
|
||||
if not is_torch_available():
|
||||
return False
|
||||
|
||||
import torch.distributed.fsdp
|
||||
|
||||
return isinstance(module, torch.distributed.fsdp.FullyShardedDataParallel) or getattr(
|
||||
module, "_is_fsdp_managed_module", False
|
||||
)
|
||||
@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
CausalLMOutput,
|
||||
@@ -826,7 +827,7 @@ class Data2VecAudioEncoder(nn.Module):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -836,8 +837,8 @@ class Data2VecAudioEncoder(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
|
||||
@@ -24,6 +24,7 @@ from torch import nn
|
||||
from ....activations import ACT2FN
|
||||
from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
from ....integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ....integrations.fsdp import is_fsdp_managed_module
|
||||
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ....modeling_outputs import BaseModelOutput, CausalLMOutput
|
||||
from ....modeling_utils import (
|
||||
@@ -579,7 +580,7 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
|
||||
f"but it is for {head_mask.size()[0]}."
|
||||
)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
@@ -588,8 +589,8 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
|
||||
@@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
@@ -968,7 +969,7 @@ class HubertEncoder(nn.Module):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -978,8 +979,8 @@ class HubertEncoder(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
@@ -1055,7 +1056,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
|
||||
hidden_states = hidden_states + position_embeddings
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -1065,8 +1066,8 @@ class HubertEncoderStableLayerNorm(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
|
||||
@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_attn_mask_utils import (
|
||||
_prepare_4d_attention_mask,
|
||||
_prepare_4d_attention_mask_for_sdpa,
|
||||
@@ -1055,7 +1056,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
|
||||
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
|
||||
f" {head_mask.size()[0]}."
|
||||
)
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
@@ -1065,8 +1066,8 @@ class M2M100Encoder(M2M100PreTrainedModel):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
@@ -1312,7 +1313,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
||||
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
||||
f" {head_mask.size()[0]}."
|
||||
)
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
@@ -1322,8 +1323,8 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
|
||||
@@ -1506,7 +1506,8 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin):
|
||||
generation config. If a stopping criteria is passed that is already created with the arguments or a
|
||||
generation config an error is thrown. This feature is intended for advanced users.
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
@@ -2513,7 +2514,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
|
||||
generation config. If a stopping criteria is passed that is already created with the arguments or a
|
||||
generation config an error is thrown. This feature is intended for advanced users.
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
|
||||
@@ -1428,7 +1428,8 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin):
|
||||
generation config. If a stopping criteria is passed that is already created with the arguments or a
|
||||
generation config an error is thrown. This feature is intended for advanced users.
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
@@ -2364,7 +2365,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
|
||||
generation config. If a stopping criteria is passed that is already created with the arguments or a
|
||||
generation config an error is thrown. This feature is intended for advanced users.
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
|
||||
@@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
MoEModelOutput,
|
||||
@@ -1370,7 +1371,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
|
||||
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
||||
f" {head_mask.size()[0]}."
|
||||
)
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
@@ -1380,13 +1381,13 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
if not skip_the_layer or synced_gpus:
|
||||
layer_head_mask = head_mask[idx] if head_mask is not None else None
|
||||
cross_attn_layer_head_mask = cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
|
||||
@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
@@ -823,7 +824,7 @@ class SeamlessM4TConformerEncoder(nn.Module):
|
||||
else:
|
||||
relative_position_embeddings = None
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
@@ -835,8 +836,8 @@ class SeamlessM4TConformerEncoder(nn.Module):
|
||||
skip_the_layer = (
|
||||
True if self.training and (dropout_probability < self.config.speech_encoder_layerdrop) else False
|
||||
)
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
@@ -2863,7 +2864,8 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin):
|
||||
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
|
||||
Retrieval](https://arxiv.org/abs/2010.00904).
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model.
|
||||
@@ -3149,7 +3151,8 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
|
||||
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
|
||||
Retrieval](https://arxiv.org/abs/2010.00904).
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model.
|
||||
|
||||
@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
|
||||
from ...activations import ACT2FN
|
||||
from ...generation import GenerationMixin
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
@@ -770,7 +771,7 @@ class SeamlessM4Tv2ConformerEncoder(nn.Module):
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
@@ -782,8 +783,8 @@ class SeamlessM4Tv2ConformerEncoder(nn.Module):
|
||||
skip_the_layer = (
|
||||
True if self.training and (dropout_probability < self.config.speech_encoder_layerdrop) else False
|
||||
)
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
@@ -3121,7 +3122,8 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin):
|
||||
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
|
||||
Retrieval](https://arxiv.org/abs/2010.00904).
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model.
|
||||
@@ -3417,7 +3419,8 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel):
|
||||
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
|
||||
Retrieval](https://arxiv.org/abs/2010.00904).
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model.
|
||||
|
||||
@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
@@ -921,7 +922,7 @@ class SEWEncoder(nn.Module):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -931,8 +932,8 @@ class SEWEncoder(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
|
||||
@@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
@@ -1318,7 +1319,7 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel):
|
||||
|
||||
position_bias = self.embed_positions(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
@@ -1341,8 +1342,8 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel):
|
||||
dropout_probability = torch.rand([])
|
||||
skip_the_layer = dropout_probability < self.layerdrop
|
||||
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
@@ -1603,7 +1604,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
|
||||
encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
@@ -1636,7 +1637,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
|
||||
if self.training:
|
||||
dropout_probability = torch.rand([])
|
||||
skip_the_layer = dropout_probability < self.layerdrop
|
||||
if skip_the_layer and not deepspeed_zero3_is_enabled:
|
||||
if skip_the_layer and not synced_gpus:
|
||||
continue
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
@@ -1004,7 +1005,7 @@ class UniSpeechEncoder(nn.Module):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -1014,8 +1015,8 @@ class UniSpeechEncoder(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
@@ -1091,7 +1092,7 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
|
||||
hidden_states = hidden_states + position_embeddings
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -1101,8 +1102,8 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
|
||||
@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
CausalLMOutput,
|
||||
@@ -1021,7 +1022,7 @@ class UniSpeechSatEncoder(nn.Module):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -1031,8 +1032,8 @@ class UniSpeechSatEncoder(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
@@ -1108,7 +1109,7 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
|
||||
hidden_states = hidden_states + position_embeddings
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -1118,8 +1119,8 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
|
||||
@@ -25,6 +25,7 @@ from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
@@ -1139,7 +1140,7 @@ class VitsEncoder(nn.Module):
|
||||
|
||||
hidden_states = hidden_states * padding_mask
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -1149,8 +1150,8 @@ class VitsEncoder(nn.Module):
|
||||
dropout_probability = np.random.uniform(0, 1)
|
||||
|
||||
skip_the_layer = self.training and (dropout_probability < self.layerdrop)
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
|
||||
@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
CausalLMOutput,
|
||||
@@ -1038,7 +1039,7 @@ class Wav2Vec2Encoder(nn.Module):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -1048,8 +1049,8 @@ class Wav2Vec2Encoder(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
@@ -1124,7 +1125,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
hidden_states = hidden_states + position_embeddings
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for layer in self.layers:
|
||||
if output_hidden_states:
|
||||
@@ -1134,8 +1135,8 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
|
||||
@@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
@@ -730,7 +731,7 @@ class Wav2Vec2BertEncoder(nn.Module):
|
||||
else:
|
||||
relative_position_embeddings = None
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
@@ -740,8 +741,8 @@ class Wav2Vec2BertEncoder(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
|
||||
@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
CausalLMOutput,
|
||||
@@ -893,7 +894,7 @@ class Wav2Vec2ConformerEncoder(nn.Module):
|
||||
else:
|
||||
relative_position_embeddings = None
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
@@ -903,8 +904,8 @@ class Wav2Vec2ConformerEncoder(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
|
||||
@@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from ...integrations.fsdp import is_fsdp_managed_module
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
CausalLMOutput,
|
||||
@@ -697,7 +698,7 @@ class WavLMEncoder(nn.Module):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
position_bias = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
@@ -708,8 +709,8 @@ class WavLMEncoder(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
@@ -781,7 +782,7 @@ class WavLMEncoderStableLayerNorm(nn.Module):
|
||||
hidden_states = hidden_states + position_embeddings
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
||||
synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
|
||||
position_bias = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
@@ -792,8 +793,8 @@ class WavLMEncoderStableLayerNorm(nn.Module):
|
||||
dropout_probability = torch.rand([])
|
||||
|
||||
skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if not skip_the_layer or synced_gpus:
|
||||
# under fsdp or deepspeed zero3 all gpus must run in sync
|
||||
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
|
||||
@@ -357,7 +357,8 @@ class WhisperGenerationMixin(GenerationMixin):
|
||||
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
|
||||
Retrieval](https://arxiv.org/abs/2010.00904).
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
||||
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
||||
return_timestamps (`bool`, *optional*):
|
||||
Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
|
||||
task (`str`, *optional*):
|
||||
|
||||
@@ -4031,7 +4031,7 @@ class Trainer:
|
||||
start_time = time.time()
|
||||
model = (
|
||||
self.accelerator.prepare(model)
|
||||
if self.is_deepspeed_enabled
|
||||
if self.is_deepspeed_enabled or self.is_fsdp_enabled
|
||||
else self.accelerator.prepare_model(model, evaluation_mode=True)
|
||||
)
|
||||
self.model_preparation_time = round(time.time() - start_time, 4)
|
||||
@@ -4634,7 +4634,7 @@ class Trainer:
|
||||
if len(self.accelerator._models) == 0 and model is self.model:
|
||||
model = (
|
||||
self.accelerator.prepare(model)
|
||||
if self.is_deepspeed_enabled
|
||||
if self.is_deepspeed_enabled or self.is_fsdp_enabled
|
||||
else self.accelerator.prepare_model(model, evaluation_mode=True)
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
@@ -19,10 +20,12 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from .generation.configuration_utils import GenerationConfig
|
||||
from .integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from .integrations.fsdp import is_fsdp_managed_module
|
||||
from .trainer import Trainer
|
||||
from .utils import is_datasets_available, logging
|
||||
from .utils.deprecation import deprecate_kwarg
|
||||
@@ -303,10 +306,8 @@ class Seq2SeqTrainer(Trainer):
|
||||
if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
|
||||
gen_kwargs.pop("max_length")
|
||||
|
||||
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
|
||||
gen_kwargs["synced_gpus"] = (
|
||||
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
|
||||
)
|
||||
default_synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self.model)
|
||||
gen_kwargs["synced_gpus"] = gen_kwargs.get("synced_gpus", default_synced_gpus)
|
||||
|
||||
generation_inputs = inputs.copy()
|
||||
# If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
|
||||
@@ -319,6 +320,14 @@ class Seq2SeqTrainer(Trainer):
|
||||
generation_inputs = {
|
||||
k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
|
||||
}
|
||||
|
||||
summon_full_params_context = (
|
||||
FullyShardedDataParallel.summon_full_params(self.model)
|
||||
if isinstance(self.model, FullyShardedDataParallel)
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
|
||||
with summon_full_params_context:
|
||||
generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
|
||||
|
||||
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
|
||||
|
||||
148
tests/generation/test_fsdp.py
Normal file
148
tests/generation/test_fsdp.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from typing import Any, Callable
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_torch_dist_unique_port,
|
||||
require_torch_multi_gpu,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
|
||||
|
||||
data = 4 * [
|
||||
"Hello world!",
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
]
|
||||
|
||||
def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""Manage the creation and destruction of the distributed process group for the wrapped function."""
|
||||
|
||||
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
torch.distributed.init_process_group(world_size=torch.cuda.device_count())
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
return wrapped
|
||||
|
||||
@manage_process_group
|
||||
def fsdp_generate():
|
||||
torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
|
||||
|
||||
fsdp_model = FullyShardedDataParallel(
|
||||
model,
|
||||
auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block}),
|
||||
limit_all_gathers=True,
|
||||
use_orig_params=True,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)
|
||||
|
||||
with FullyShardedDataParallel.summon_full_params(fsdp_model):
|
||||
_ = fsdp_model.module.generate(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
max_length=30,
|
||||
)
|
||||
|
||||
@manage_process_group
|
||||
def fsdp2_generate():
|
||||
torch.cuda.set_device(device := torch.device(rank := torch.distributed.get_rank()))
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
|
||||
|
||||
mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
|
||||
for submodule in model.modules():
|
||||
if isinstance(submodule, GPT2Block):
|
||||
fully_shard(submodule, mesh=mesh)
|
||||
fully_shard(model, mesh=mesh)
|
||||
|
||||
register_fsdp_forward_method(model, "generate")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)
|
||||
|
||||
_ = model.generate(
|
||||
input_ids=batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
max_length=30,
|
||||
)
|
||||
|
||||
|
||||
class TestFSDPGeneration(TestCasePlus):
|
||||
@require_torch_multi_gpu
|
||||
def test_fsdp_generate(self):
|
||||
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
|
||||
--master_port={get_torch_dist_unique_port()}
|
||||
{self.test_file_dir}/test_fsdp.py
|
||||
""".split()
|
||||
args = "--fsdp".split()
|
||||
cmd = ["torchrun"] + distributed_args + args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
# successful return here == success - any errors would have caused an error in the sub-call
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_fsdp2_generate(self):
|
||||
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
|
||||
--master_port={get_torch_dist_unique_port()}
|
||||
{self.test_file_dir}/test_fsdp.py
|
||||
""".split()
|
||||
args = "--fsdp2".split()
|
||||
cmd = ["torchrun"] + distributed_args + args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
# successful return here == success - any errors would have caused an error in the sub-call
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
|
||||
#
|
||||
# PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 --output_dir output_dir ./tests/generation/test_fsdp.py --fsdp
|
||||
|
||||
class CLIArgs(argparse.Namespace):
|
||||
fsdp: bool
|
||||
fsdp2: bool
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument("--fsdp", action="store_true")
|
||||
group.add_argument("--fsdp2", action="store_true")
|
||||
args = parser.parse_args(namespace=CLIArgs())
|
||||
|
||||
if args.fsdp:
|
||||
fsdp_generate()
|
||||
elif args.fsdp2:
|
||||
fsdp2_generate()
|
||||
else:
|
||||
raise ValueError("Missing test selection")
|
||||
114
tests/trainer/test_trainer_fsdp.py
Normal file
114
tests/trainer/test_trainer_fsdp.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
get_torch_dist_unique_port,
|
||||
require_accelerate,
|
||||
require_torch_multi_gpu,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.utils.data
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
EvalPrediction,
|
||||
GenerationConfig,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
)
|
||||
|
||||
class DummyTextDataset(torch.utils.data.Dataset[str]):
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
data = 4 * [
|
||||
"Hello world!",
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
]
|
||||
self.data = [
|
||||
{k: v.squeeze(0) for k, v in tokenizer(item, return_tensors="pt", return_attention_mask=True).items()}
|
||||
for item in data
|
||||
]
|
||||
for item in self.data:
|
||||
item["labels"] = item["input_ids"]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, i: int) -> str:
|
||||
return self.data[i]
|
||||
|
||||
|
||||
class TestFSDPTrainer(TestCasePlus):
|
||||
@require_accelerate
|
||||
@require_torch_multi_gpu
|
||||
def test_trainer(self):
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--use_fsdp",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"--num_processes",
|
||||
f"{torch.cuda.device_count()}",
|
||||
"--fsdp_transformer_layer_cls_to_wrap",
|
||||
"GPT2Block",
|
||||
f"{self.test_file_dir}/test_trainer_fsdp.py",
|
||||
"--output_dir",
|
||||
f"{output_dir}",
|
||||
"--report_to",
|
||||
"none",
|
||||
]
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
# successful return here == success - any errors would have caused an error in the sub-call
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((Seq2SeqTrainingArguments,))
|
||||
training_args = parser.parse_args_into_dataclasses()[0]
|
||||
training_args.per_device_eval_batch_size = 1
|
||||
training_args.use_legacy_prediction_loop = False
|
||||
training_args.predict_with_generate = True
|
||||
training_args.generation_config = GenerationConfig(max_length=30)
|
||||
|
||||
pretrained_model_name = "hf-internal-testing/tiny-random-gpt2"
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
device = torch.device(torch.distributed.get_rank())
|
||||
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name).to(device)
|
||||
|
||||
def compute_metrics(p: EvalPrediction) -> Dict[str, bool]:
|
||||
return {"accuracy": (p.predictions == p.label_ids).mean()}
|
||||
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer, model),
|
||||
eval_dataset=DummyTextDataset(tokenizer),
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
Reference in New Issue
Block a user