Remove graph breaks for torch.compile() in flash_attention_forward when Lllama Model is padding free tuned (#33932)

* fix: fixes for graph breaks

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: formatting

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: import error

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: Add Fa2Kwargs

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* Revert "PR changes"

This reverts commit 39d2868e5c93cc5f3f3c7c6ff981b66614c0e0e4.

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: FlashAttentionKwarg

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: FlashAttentionKwarg

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* addition of documentation

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* change in _flash_attention_forward

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* make fix-copies

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* revert make fix-copies

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix copies

* style

* loss kwargs typing

* style and pull latest changes

---------

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Abhishek Maurya
2024-10-24 05:02:54 -04:00
committed by GitHub
parent b0f0c61899
commit 65753d6065
9 changed files with 192 additions and 20 deletions

View File

@@ -348,6 +348,99 @@ model = AutoModelForCausalLM.from_pretrained(
)
```
### Fine-Tuning with torch.compile and Padding-Free Data Collation
In addition to optimizing inference, you can also enhance the training efficiency of large language models by leveraging torch.compile during fine-tuning and using a padding-free data collator. This approach can significantly speed up training and reduce computational overhead.
Here's how you can fine-tune a Llama model using SFTTrainer from the TRL library, with torch_compile enabled and a padding-free data collator:
```
#################### IMPORTS ###################
import math
import datasets
import dataclasses
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments
)
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
#################### MODEL LOADING WITH FLASH ATTENTION ###################
model_name = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="flash_attention_2" # Enables FlashAttention-2
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
#################### DATA PREPROCESSING (PADDING-FREE) ###################
response_template = "\n### Label:"
response_template_ids = tokenizer.encode(
response_template, add_special_tokens=False
)[2:] # Exclude special tokens
data_collator = DataCollatorForCompletionOnlyLM(
response_template_ids=response_template_ids,
tokenizer=tokenizer,
ignore_index=-100,
padding_free=True # Enables padding-free collation
)
def format_dataset(example):
return {
"output": example["output"] + tokenizer.eos_token
}
data_files = {"train": "path/to/dataset"} # Replace with your dataset path
json_dataset = datasets.load_dataset("json", data_files=data_files)
formatted_train_dataset = json_dataset["train"].map(format_dataset)
################# TRAINING CONFIGURATION ############################
train_args = TrainingArguments(
num_train_epochs=5,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=1e-5,
weight_decay=0.0,
warmup_ratio=0.03,
lr_scheduler_type="cosine",
logging_steps=1,
include_tokens_per_second=True,
save_strategy="epoch",
output_dir="output",
torch_compile=True, # Enables torch.compile
torch_compile_backend="inductor",
torch_compile_mode="default"
)
# Convert TrainingArguments to SFTConfig
transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
transformer_kwargs = {
k: v
for k, v in train_args.to_dict().items()
if k in transformer_train_arg_fields
}
training_args = SFTConfig(**transformer_kwargs)
####################### FINE-TUNING #####################
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=formatted_train_dataset,
data_collator=data_collator,
dataset_text_field="output",
args=training_args,
)
trainer.train()
```
### PyTorch scaled dot product attention
Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.

View File

@@ -15,7 +15,7 @@
import inspect
import os
from typing import Optional, Tuple
from typing import Optional, Tuple, TypedDict
import torch
import torch.nn.functional as F
@@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
flash_241 = is_flash_attn_greater_or_equal("2.4.1")
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
@@ -194,6 +198,10 @@ def _flash_attention_forward(
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = None,
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -232,9 +240,9 @@ def _flash_attention_forward(
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if is_flash_attn_greater_or_equal("2.4.1"):
if flash_241:
if deterministic is None:
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
deterministic = deterministic_g
flash_kwargs["deterministic"] = deterministic
if softcap is not None:
@@ -267,24 +275,32 @@ def _flash_attention_forward(
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
# Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
elif position_ids is not None and (
max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
if cu_seq_lens_q is None or cu_seq_lens_k is None:
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens
else:
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
@@ -299,3 +315,24 @@ def _flash_attention_forward(
)
return attn_output
class FlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for Flash Attention with Compile.
Attributes:
cu_seq_lens_q (`torch.LongTensor`, *optional*)
Gets cumlative sequence length for query state.
cu_seq_lens_k (`torch.LongTensor`, *optional*)
Gets cumlative sequence length for key state.
max_length_q (`int`, *optional*):
Maximum sequence length for query state.
max_length_k (`int`, *optional*):
Maximum sequence length for key state.
"""
cu_seq_lens_q: Optional[torch.LongTensor]
cu_seq_lens_k: Optional[torch.LongTensor]
max_length_q: Optional[int]
max_length_k: Optional[int]

View File

@@ -33,12 +33,14 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
@@ -832,6 +834,7 @@ class CohereModel(CoherePreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -913,6 +916,7 @@ class CohereModel(CoherePreTrainedModel):
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]

View File

@@ -38,6 +38,7 @@ from ...modeling_outputs import (
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
@@ -51,7 +52,11 @@ from .configuration_glm import GlmConfig
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...processing_utils import Unpack
_CHECKPOINT_FOR_DOC = "dummy"
class GlmRMSNorm(nn.Module):
@@ -736,6 +741,7 @@ class GlmModel(GlmPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -817,6 +823,7 @@ class GlmModel(GlmPreTrainedModel):
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
@@ -1222,6 +1229,11 @@ class GlmForTokenClassification(GlmPreTrainedModel):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -46,6 +46,8 @@ from .configuration_glm import GlmConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "dummy"
class GlmRMSNorm(Phi3RMSNorm):
pass

View File

@@ -29,7 +29,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -39,8 +39,10 @@ from ...modeling_outputs import (
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -422,6 +424,7 @@ class LlamaFlashAttention2(LlamaAttention):
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
@@ -506,6 +509,7 @@ class LlamaFlashAttention2(LlamaAttention):
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
**kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -870,6 +874,7 @@ class LlamaModel(LlamaPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -951,6 +956,7 @@ class LlamaModel(LlamaPreTrainedModel):
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
@@ -1102,6 +1108,9 @@ class LlamaModel(LlamaPreTrainedModel):
return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
@@ -1148,7 +1157,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1198,6 +1207,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
@@ -1211,7 +1221,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]

View File

@@ -815,7 +815,7 @@ class BatchEncoding(UserDict):
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
self.data = {k: v.to(device=device) for k, v in self.data.items() if isinstance(v, torch.Tensor)}
self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
return self

View File

@@ -37,6 +37,7 @@ from .doc import (
from .generic import (
ContextManagers,
ExplicitEnum,
LossKwargs,
ModelOutput,
PaddingStrategy,
TensorType,

View File

@@ -24,7 +24,7 @@ from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import partial, wraps
from typing import Any, ContextManager, Iterable, List, Optional, Tuple
from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict
import numpy as np
from packaging import version
@@ -854,3 +854,16 @@ def filter_out_non_signature_kwargs(extra: Optional[list] = None):
return wrapper
return decorator
class LossKwargs(TypedDict, total=False):
"""
Keyword arguments to be passed to the loss function
Attributes:
num_items_in_batch (`int`, *optional*):
Number of items in the batch. It is recommended to pass it when
you are doing gradient accumulation.
"""
num_items_in_batch: Optional[int]