Llama et al. / FSDP : Fix breaking change in 4.40 for FSDP (#31161)
* fix llama fsdp * fixup * adding FSDP tests for CPU offloading * fixes * fix tests * fix tests * add it for mixtral * propagate the changes on other models * Update src/transformers/models/phi/modeling_phi.py * Delete utils/testing_scripts/fsdp_cpu_offloading.py Remove script - FSDP + CPU offloading it tested in the test suite * Delete utils/testing_scripts/dummy_fsdp_config.yml * Update + add cache_positions docstring --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -782,6 +782,7 @@ class FalconDecoderLayer(nn.Module):
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
@@ -628,6 +628,7 @@ class GemmaDecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -642,6 +643,11 @@ class GemmaDecoderLayer(nn.Module):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
@@ -706,6 +706,7 @@ class GPTBigCodeBlock(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Union[
|
||||
Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
]:
|
||||
|
||||
@@ -301,6 +301,7 @@ class LlamaAttention(nn.Module):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@@ -590,6 +591,7 @@ class LlamaSdpaAttention(LlamaAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
@@ -687,6 +689,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -701,6 +704,11 @@ class LlamaDecoderLayer(nn.Module):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
@@ -591,6 +591,7 @@ class MistralSdpaAttention(MistralAttention):
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||
@@ -689,6 +690,7 @@ class MistralDecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -703,8 +705,12 @@ class MistralDecoderLayer(nn.Module):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
@@ -888,6 +888,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
output_router_logits: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -906,6 +907,9 @@ class MixtralDecoderLayer(nn.Module):
|
||||
(see `past_key_values`).
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -666,6 +666,7 @@ class OlmoDecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -680,6 +681,11 @@ class OlmoDecoderLayer(nn.Module):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
@@ -771,6 +771,7 @@ class PhiDecoderLayer(nn.Module):
|
||||
use_cache: Optional[bool] = False,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -790,6 +791,9 @@ class PhiDecoderLayer(nn.Module):
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -830,6 +830,7 @@ class Phi3DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -847,6 +848,11 @@ class Phi3DecoderLayer(nn.Module):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -734,6 +734,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -749,6 +750,9 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -876,6 +876,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
output_router_logits: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -894,6 +895,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -713,6 +713,7 @@ class Starcoder2DecoderLayer(nn.Module):
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
@@ -728,6 +729,9 @@ class Starcoder2DecoderLayer(nn.Module):
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
kwargs (`dict`, *optional*):
|
||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
||||
into the model
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
@@ -31,6 +32,7 @@ from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_fsdp,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -276,6 +278,20 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
|
||||
if "learning_rate" in log:
|
||||
self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5)
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_fsdp
|
||||
def test_fsdp_cpu_offloading(self):
|
||||
try:
|
||||
subprocess.run(
|
||||
"accelerate launch utils/testing_scripts/fsdp_cpu_offloading.py --config utils/testing_scripts/dummy_fsdp_config.yml",
|
||||
shell=True,
|
||||
check=True,
|
||||
)
|
||||
except: # noqa
|
||||
raise AssertionError("CPU offloading failed with FSDP!")
|
||||
|
||||
def run_cmd_and_get_logs(self, use_accelerate, sharding_strategy, launcher, script, args, output_dir):
|
||||
if not use_accelerate:
|
||||
fsdp_args = [
|
||||
|
||||
Reference in New Issue
Block a user