Remove graph breaks for torch.compile() in flash_attention_forward when Lllama Model is padding free tuned (#33932)
* fix: fixes for graph breaks Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: formatting Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: import error Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: Add Fa2Kwargs Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * Revert "PR changes" This reverts commit 39d2868e5c93cc5f3f3c7c6ff981b66614c0e0e4. * PR changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: FlashAttentionKwarg Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: FlashAttentionKwarg Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * addition of documentation Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * change in _flash_attention_forward Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * make fix-copies Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * revert make fix-copies Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix copies * style * loss kwargs typing * style and pull latest changes --------- Signed-off-by: Abhishek <maurya.abhishek@ibm.com> Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
@@ -348,6 +348,99 @@ model = AutoModelForCausalLM.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
### Fine-Tuning with torch.compile and Padding-Free Data Collation
|
||||
|
||||
In addition to optimizing inference, you can also enhance the training efficiency of large language models by leveraging torch.compile during fine-tuning and using a padding-free data collator. This approach can significantly speed up training and reduce computational overhead.
|
||||
|
||||
Here's how you can fine-tune a Llama model using SFTTrainer from the TRL library, with torch_compile enabled and a padding-free data collator:
|
||||
|
||||
```
|
||||
#################### IMPORTS ###################
|
||||
|
||||
import math
|
||||
import datasets
|
||||
import dataclasses
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
TrainingArguments
|
||||
)
|
||||
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
|
||||
|
||||
#################### MODEL LOADING WITH FLASH ATTENTION ###################
|
||||
|
||||
model_name = "meta-llama/Llama-3.2-1B"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
attn_implementation="flash_attention_2" # Enables FlashAttention-2
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||
|
||||
#################### DATA PREPROCESSING (PADDING-FREE) ###################
|
||||
|
||||
response_template = "\n### Label:"
|
||||
response_template_ids = tokenizer.encode(
|
||||
response_template, add_special_tokens=False
|
||||
)[2:] # Exclude special tokens
|
||||
|
||||
data_collator = DataCollatorForCompletionOnlyLM(
|
||||
response_template_ids=response_template_ids,
|
||||
tokenizer=tokenizer,
|
||||
ignore_index=-100,
|
||||
padding_free=True # Enables padding-free collation
|
||||
)
|
||||
|
||||
def format_dataset(example):
|
||||
return {
|
||||
"output": example["output"] + tokenizer.eos_token
|
||||
}
|
||||
|
||||
data_files = {"train": "path/to/dataset"} # Replace with your dataset path
|
||||
json_dataset = datasets.load_dataset("json", data_files=data_files)
|
||||
formatted_train_dataset = json_dataset["train"].map(format_dataset)
|
||||
|
||||
################# TRAINING CONFIGURATION ############################
|
||||
|
||||
train_args = TrainingArguments(
|
||||
num_train_epochs=5,
|
||||
per_device_train_batch_size=4,
|
||||
per_device_eval_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=1e-5,
|
||||
weight_decay=0.0,
|
||||
warmup_ratio=0.03,
|
||||
lr_scheduler_type="cosine",
|
||||
logging_steps=1,
|
||||
include_tokens_per_second=True,
|
||||
save_strategy="epoch",
|
||||
output_dir="output",
|
||||
torch_compile=True, # Enables torch.compile
|
||||
torch_compile_backend="inductor",
|
||||
torch_compile_mode="default"
|
||||
)
|
||||
|
||||
# Convert TrainingArguments to SFTConfig
|
||||
transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
|
||||
transformer_kwargs = {
|
||||
k: v
|
||||
for k, v in train_args.to_dict().items()
|
||||
if k in transformer_train_arg_fields
|
||||
}
|
||||
training_args = SFTConfig(**transformer_kwargs)
|
||||
|
||||
####################### FINE-TUNING #####################
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=formatted_train_dataset,
|
||||
data_collator=data_collator,
|
||||
dataset_text_field="output",
|
||||
args=training_args,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### PyTorch scaled dot product attention
|
||||
|
||||
Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
|
||||
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
|
||||
|
||||
|
||||
flash_241 = is_flash_attn_greater_or_equal("2.4.1")
|
||||
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
|
||||
|
||||
def _flash_attention_forward(
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
@@ -194,6 +198,10 @@ def _flash_attention_forward(
|
||||
use_top_left_mask: bool = False,
|
||||
softcap: Optional[float] = None,
|
||||
deterministic: bool = None,
|
||||
cu_seq_lens_q: Optional[torch.LongTensor] = None,
|
||||
cu_seq_lens_k: Optional[torch.LongTensor] = None,
|
||||
max_length_q: Optional[int] = None,
|
||||
max_length_k: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||
@@ -232,9 +240,9 @@ def _flash_attention_forward(
|
||||
)
|
||||
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
||||
|
||||
if is_flash_attn_greater_or_equal("2.4.1"):
|
||||
if flash_241:
|
||||
if deterministic is None:
|
||||
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||
deterministic = deterministic_g
|
||||
flash_kwargs["deterministic"] = deterministic
|
||||
|
||||
if softcap is not None:
|
||||
@@ -267,24 +275,32 @@ def _flash_attention_forward(
|
||||
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
||||
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
||||
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||
# Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
|
||||
elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
|
||||
elif position_ids is not None and (
|
||||
max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
||||
):
|
||||
batch_size = query_states.size(0)
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
||||
query_states, key_states, value_states, position_ids
|
||||
|
||||
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
|
||||
prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
|
||||
)
|
||||
|
||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
||||
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
||||
max_length_q, max_length_k = max_seq_lens
|
||||
|
||||
else:
|
||||
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
|
||||
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
|
||||
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_in_batch_q,
|
||||
max_seqlen_k=max_seqlen_in_batch_k,
|
||||
cu_seqlens_q=cu_seq_lens_q,
|
||||
cu_seqlens_k=cu_seq_lens_k,
|
||||
max_seqlen_q=max_length_q,
|
||||
max_seqlen_k=max_length_k,
|
||||
dropout_p=dropout,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
@@ -299,3 +315,24 @@ def _flash_attention_forward(
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class FlashAttentionKwargs(TypedDict, total=False):
|
||||
"""
|
||||
Keyword arguments for Flash Attention with Compile.
|
||||
|
||||
Attributes:
|
||||
cu_seq_lens_q (`torch.LongTensor`, *optional*)
|
||||
Gets cumlative sequence length for query state.
|
||||
cu_seq_lens_k (`torch.LongTensor`, *optional*)
|
||||
Gets cumlative sequence length for key state.
|
||||
max_length_q (`int`, *optional*):
|
||||
Maximum sequence length for query state.
|
||||
max_length_k (`int`, *optional*):
|
||||
Maximum sequence length for key state.
|
||||
"""
|
||||
|
||||
cu_seq_lens_q: Optional[torch.LongTensor]
|
||||
cu_seq_lens_k: Optional[torch.LongTensor]
|
||||
max_length_q: Optional[int]
|
||||
max_length_k: Optional[int]
|
||||
|
||||
@@ -33,12 +33,14 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import (
|
||||
add_start_docstrings,
|
||||
@@ -832,6 +834,7 @@ class CohereModel(CoherePreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -913,6 +916,7 @@ class CohereModel(CoherePreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
@@ -38,6 +38,7 @@ from ...modeling_outputs import (
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
@@ -51,7 +52,11 @@ from .configuration_glm import GlmConfig
|
||||
if is_flash_attn_2_available():
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
|
||||
from ...processing_utils import Unpack
|
||||
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "dummy"
|
||||
|
||||
|
||||
class GlmRMSNorm(nn.Module):
|
||||
@@ -736,6 +741,7 @@ class GlmModel(GlmPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -817,6 +823,7 @@ class GlmModel(GlmPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@@ -1222,6 +1229,11 @@ class GlmForTokenClassification(GlmPreTrainedModel):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
@add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
|
||||
@@ -46,6 +46,8 @@ from .configuration_glm import GlmConfig
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "dummy"
|
||||
|
||||
|
||||
class GlmRMSNorm(Phi3RMSNorm):
|
||||
pass
|
||||
|
||||
@@ -29,7 +29,7 @@ from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
@@ -39,8 +39,10 @@ from ...modeling_outputs import (
|
||||
)
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from ...utils import (
|
||||
LossKwargs,
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
@@ -422,6 +424,7 @@ class LlamaFlashAttention2(LlamaAttention):
|
||||
use_cache: bool = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if isinstance(past_key_value, StaticCache):
|
||||
raise ValueError(
|
||||
@@ -506,6 +509,7 @@ class LlamaFlashAttention2(LlamaAttention):
|
||||
sliding_window=getattr(self, "sliding_window", None),
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
is_causal=self.is_causal,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||
@@ -870,6 +874,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
@@ -951,6 +956,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
@@ -1102,6 +1108,9 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
return causal_mask
|
||||
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||
|
||||
|
||||
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
|
||||
@@ -1148,7 +1157,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
**loss_kwargs,
|
||||
**kwargs: Unpack[KwargsForCausalLM],
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
@@ -1198,6 +1207,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@@ -1211,7 +1221,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
|
||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
|
||||
@@ -815,7 +815,7 @@ class BatchEncoding(UserDict):
|
||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||
# into a HalfTensor
|
||||
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
|
||||
self.data = {k: v.to(device=device) for k, v in self.data.items() if isinstance(v, torch.Tensor)}
|
||||
self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
|
||||
else:
|
||||
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
|
||||
return self
|
||||
|
||||
@@ -37,6 +37,7 @@ from .doc import (
|
||||
from .generic import (
|
||||
ContextManagers,
|
||||
ExplicitEnum,
|
||||
LossKwargs,
|
||||
ModelOutput,
|
||||
PaddingStrategy,
|
||||
TensorType,
|
||||
|
||||
@@ -24,7 +24,7 @@ from contextlib import ExitStack, contextmanager
|
||||
from dataclasses import fields, is_dataclass
|
||||
from enum import Enum
|
||||
from functools import partial, wraps
|
||||
from typing import Any, ContextManager, Iterable, List, Optional, Tuple
|
||||
from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
@@ -854,3 +854,16 @@ def filter_out_non_signature_kwargs(extra: Optional[list] = None):
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class LossKwargs(TypedDict, total=False):
|
||||
"""
|
||||
Keyword arguments to be passed to the loss function
|
||||
|
||||
Attributes:
|
||||
num_items_in_batch (`int`, *optional*):
|
||||
Number of items in the batch. It is recommended to pass it when
|
||||
you are doing gradient accumulation.
|
||||
"""
|
||||
|
||||
num_items_in_batch: Optional[int]
|
||||
|
||||
Reference in New Issue
Block a user