Llama et al. / FSDP : Fix breaking change in 4.40 for FSDP (#31161)

* fix llama fsdp

* fixup

* adding FSDP tests for CPU offloading

* fixes

* fix tests

* fix tests

* add it for mixtral

* propagate the changes on other models

* Update src/transformers/models/phi/modeling_phi.py

* Delete utils/testing_scripts/fsdp_cpu_offloading.py

Remove script - FSDP + CPU offloading it tested in the test suite

* Delete utils/testing_scripts/dummy_fsdp_config.yml

* Update + add cache_positions docstring

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2024-06-26 15:50:08 +02:00
committed by GitHub
parent ac52084bf2
commit 3f93fd0694
13 changed files with 71 additions and 1 deletions

View File

@@ -782,6 +782,7 @@ class FalconDecoderLayer(nn.Module):
head_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False, use_cache: bool = False,
output_attentions: bool = False, output_attentions: bool = False,
**kwargs,
): ):
residual = hidden_states residual = hidden_states

View File

@@ -628,6 +628,7 @@ class GemmaDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -642,6 +643,11 @@ class GemmaDecoderLayer(nn.Module):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`). (see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
""" """
residual = hidden_states residual = hidden_states

View File

@@ -706,6 +706,7 @@ class GPTBigCodeBlock(nn.Module):
encoder_attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
**kwargs,
) -> Union[ ) -> Union[
Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor] Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
]: ]:

View File

@@ -301,6 +301,7 @@ class LlamaAttention(nn.Module):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
@@ -590,6 +591,7 @@ class LlamaSdpaAttention(LlamaAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -687,6 +689,7 @@ class LlamaDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -701,6 +704,11 @@ class LlamaDecoderLayer(nn.Module):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`). (see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
""" """
residual = hidden_states residual = hidden_states

View File

@@ -591,6 +591,7 @@ class MistralSdpaAttention(MistralAttention):
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -689,6 +690,7 @@ class MistralDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -703,8 +705,12 @@ class MistralDecoderLayer(nn.Module):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`). (see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
""" """
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)

View File

@@ -888,6 +888,7 @@ class MixtralDecoderLayer(nn.Module):
output_router_logits: Optional[bool] = False, output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -906,6 +907,9 @@ class MixtralDecoderLayer(nn.Module):
(see `past_key_values`). (see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
""" """
residual = hidden_states residual = hidden_states

View File

@@ -666,6 +666,7 @@ class OlmoDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -680,6 +681,11 @@ class OlmoDecoderLayer(nn.Module):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`). (see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
""" """
residual = hidden_states residual = hidden_states

View File

@@ -771,6 +771,7 @@ class PhiDecoderLayer(nn.Module):
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -790,6 +791,9 @@ class PhiDecoderLayer(nn.Module):
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
""" """
residual = hidden_states residual = hidden_states

View File

@@ -830,6 +830,7 @@ class Phi3DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -847,6 +848,11 @@ class Phi3DecoderLayer(nn.Module):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`). (see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
""" """
residual = hidden_states residual = hidden_states

View File

@@ -734,6 +734,7 @@ class Qwen2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -749,6 +750,9 @@ class Qwen2DecoderLayer(nn.Module):
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
""" """
residual = hidden_states residual = hidden_states

View File

@@ -876,6 +876,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
output_router_logits: Optional[bool] = False, output_router_logits: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -894,6 +895,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
""" """
residual = hidden_states residual = hidden_states

View File

@@ -713,6 +713,7 @@ class Starcoder2DecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -728,6 +729,9 @@ class Starcoder2DecoderLayer(nn.Module):
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Indices depicting the position of the input sequence tokens in the sequence.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
""" """
residual = hidden_states residual = hidden_states

View File

@@ -14,6 +14,7 @@
import itertools import itertools
import os import os
import subprocess
import unittest import unittest
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
@@ -31,6 +32,7 @@ from transformers.testing_utils import (
require_accelerate, require_accelerate,
require_fsdp, require_fsdp,
require_torch_accelerator, require_torch_accelerator,
require_torch_gpu,
require_torch_multi_accelerator, require_torch_multi_accelerator,
slow, slow,
torch_device, torch_device,
@@ -276,6 +278,20 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
if "learning_rate" in log: if "learning_rate" in log:
self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5) self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5)
@require_torch_multi_accelerator
@slow
@require_torch_gpu
@require_fsdp
def test_fsdp_cpu_offloading(self):
try:
subprocess.run(
"accelerate launch utils/testing_scripts/fsdp_cpu_offloading.py --config utils/testing_scripts/dummy_fsdp_config.yml",
shell=True,
check=True,
)
except: # noqa
raise AssertionError("CPU offloading failed with FSDP!")
def run_cmd_and_get_logs(self, use_accelerate, sharding_strategy, launcher, script, args, output_dir): def run_cmd_and_get_logs(self, use_accelerate, sharding_strategy, launcher, script, args, output_dir):
if not use_accelerate: if not use_accelerate:
fsdp_args = [ fsdp_args = [