Remove graph breaks for torch.compile() in flash_attention_forward when Lllama Model is padding free tuned (#33932)
* fix: fixes for graph breaks Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: formatting Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: import error Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: Add Fa2Kwargs Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * Revert "PR changes" This reverts commit 39d2868e5c93cc5f3f3c7c6ff981b66614c0e0e4. * PR changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: FlashAttentionKwarg Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix: FlashAttentionKwarg Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * PR Changes Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * addition of documentation Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * change in _flash_attention_forward Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * make fix-copies Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * revert make fix-copies Signed-off-by: Abhishek <maurya.abhishek@ibm.com> * fix copies * style * loss kwargs typing * style and pull latest changes --------- Signed-off-by: Abhishek <maurya.abhishek@ibm.com> Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
@@ -348,6 +348,99 @@ model = AutoModelForCausalLM.from_pretrained(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Fine-Tuning with torch.compile and Padding-Free Data Collation
|
||||||
|
|
||||||
|
In addition to optimizing inference, you can also enhance the training efficiency of large language models by leveraging torch.compile during fine-tuning and using a padding-free data collator. This approach can significantly speed up training and reduce computational overhead.
|
||||||
|
|
||||||
|
Here's how you can fine-tune a Llama model using SFTTrainer from the TRL library, with torch_compile enabled and a padding-free data collator:
|
||||||
|
|
||||||
|
```
|
||||||
|
#################### IMPORTS ###################
|
||||||
|
|
||||||
|
import math
|
||||||
|
import datasets
|
||||||
|
import dataclasses
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
TrainingArguments
|
||||||
|
)
|
||||||
|
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
|
||||||
|
|
||||||
|
#################### MODEL LOADING WITH FLASH ATTENTION ###################
|
||||||
|
|
||||||
|
model_name = "meta-llama/Llama-3.2-1B"
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
attn_implementation="flash_attention_2" # Enables FlashAttention-2
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
|
|
||||||
|
#################### DATA PREPROCESSING (PADDING-FREE) ###################
|
||||||
|
|
||||||
|
response_template = "\n### Label:"
|
||||||
|
response_template_ids = tokenizer.encode(
|
||||||
|
response_template, add_special_tokens=False
|
||||||
|
)[2:] # Exclude special tokens
|
||||||
|
|
||||||
|
data_collator = DataCollatorForCompletionOnlyLM(
|
||||||
|
response_template_ids=response_template_ids,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
ignore_index=-100,
|
||||||
|
padding_free=True # Enables padding-free collation
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_dataset(example):
|
||||||
|
return {
|
||||||
|
"output": example["output"] + tokenizer.eos_token
|
||||||
|
}
|
||||||
|
|
||||||
|
data_files = {"train": "path/to/dataset"} # Replace with your dataset path
|
||||||
|
json_dataset = datasets.load_dataset("json", data_files=data_files)
|
||||||
|
formatted_train_dataset = json_dataset["train"].map(format_dataset)
|
||||||
|
|
||||||
|
################# TRAINING CONFIGURATION ############################
|
||||||
|
|
||||||
|
train_args = TrainingArguments(
|
||||||
|
num_train_epochs=5,
|
||||||
|
per_device_train_batch_size=4,
|
||||||
|
per_device_eval_batch_size=4,
|
||||||
|
gradient_accumulation_steps=4,
|
||||||
|
learning_rate=1e-5,
|
||||||
|
weight_decay=0.0,
|
||||||
|
warmup_ratio=0.03,
|
||||||
|
lr_scheduler_type="cosine",
|
||||||
|
logging_steps=1,
|
||||||
|
include_tokens_per_second=True,
|
||||||
|
save_strategy="epoch",
|
||||||
|
output_dir="output",
|
||||||
|
torch_compile=True, # Enables torch.compile
|
||||||
|
torch_compile_backend="inductor",
|
||||||
|
torch_compile_mode="default"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert TrainingArguments to SFTConfig
|
||||||
|
transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
|
||||||
|
transformer_kwargs = {
|
||||||
|
k: v
|
||||||
|
for k, v in train_args.to_dict().items()
|
||||||
|
if k in transformer_train_arg_fields
|
||||||
|
}
|
||||||
|
training_args = SFTConfig(**transformer_kwargs)
|
||||||
|
|
||||||
|
####################### FINE-TUNING #####################
|
||||||
|
|
||||||
|
trainer = SFTTrainer(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
train_dataset=formatted_train_dataset,
|
||||||
|
data_collator=data_collator,
|
||||||
|
dataset_text_field="output",
|
||||||
|
args=training_args,
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
```
|
||||||
|
|
||||||
### PyTorch scaled dot product attention
|
### PyTorch scaled dot product attention
|
||||||
|
|
||||||
Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.
|
Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, TypedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
|
|||||||
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
|
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
|
||||||
|
|
||||||
|
|
||||||
|
flash_241 = is_flash_attn_greater_or_equal("2.4.1")
|
||||||
|
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
||||||
|
|
||||||
|
|
||||||
def _flash_attention_forward(
|
def _flash_attention_forward(
|
||||||
query_states: torch.Tensor,
|
query_states: torch.Tensor,
|
||||||
key_states: torch.Tensor,
|
key_states: torch.Tensor,
|
||||||
@@ -194,6 +198,10 @@ def _flash_attention_forward(
|
|||||||
use_top_left_mask: bool = False,
|
use_top_left_mask: bool = False,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
deterministic: bool = None,
|
deterministic: bool = None,
|
||||||
|
cu_seq_lens_q: Optional[torch.LongTensor] = None,
|
||||||
|
cu_seq_lens_k: Optional[torch.LongTensor] = None,
|
||||||
|
max_length_q: Optional[int] = None,
|
||||||
|
max_length_k: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
||||||
@@ -232,9 +240,9 @@ def _flash_attention_forward(
|
|||||||
)
|
)
|
||||||
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
|
||||||
|
|
||||||
if is_flash_attn_greater_or_equal("2.4.1"):
|
if flash_241:
|
||||||
if deterministic is None:
|
if deterministic is None:
|
||||||
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
|
deterministic = deterministic_g
|
||||||
flash_kwargs["deterministic"] = deterministic
|
flash_kwargs["deterministic"] = deterministic
|
||||||
|
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
@@ -267,24 +275,32 @@ def _flash_attention_forward(
|
|||||||
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
||||||
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
||||||
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||||
# Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
|
elif position_ids is not None and (
|
||||||
elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
|
max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
||||||
|
):
|
||||||
batch_size = query_states.size(0)
|
batch_size = query_states.size(0)
|
||||||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
|
|
||||||
query_states, key_states, value_states, position_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
|
||||||
|
prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
||||||
|
max_length_q, max_length_k = max_seq_lens
|
||||||
|
|
||||||
|
else:
|
||||||
|
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
|
||||||
|
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
|
||||||
|
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
|
||||||
|
|
||||||
attn_output = flash_attn_varlen_func(
|
attn_output = flash_attn_varlen_func(
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
cu_seqlens_q=cu_seqlens_q,
|
cu_seqlens_q=cu_seq_lens_q,
|
||||||
cu_seqlens_k=cu_seqlens_k,
|
cu_seqlens_k=cu_seq_lens_k,
|
||||||
max_seqlen_q=max_seqlen_in_batch_q,
|
max_seqlen_q=max_length_q,
|
||||||
max_seqlen_k=max_seqlen_in_batch_k,
|
max_seqlen_k=max_length_k,
|
||||||
dropout_p=dropout,
|
dropout_p=dropout,
|
||||||
softmax_scale=softmax_scale,
|
softmax_scale=softmax_scale,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
@@ -299,3 +315,24 @@ def _flash_attention_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class FlashAttentionKwargs(TypedDict, total=False):
|
||||||
|
"""
|
||||||
|
Keyword arguments for Flash Attention with Compile.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
cu_seq_lens_q (`torch.LongTensor`, *optional*)
|
||||||
|
Gets cumlative sequence length for query state.
|
||||||
|
cu_seq_lens_k (`torch.LongTensor`, *optional*)
|
||||||
|
Gets cumlative sequence length for key state.
|
||||||
|
max_length_q (`int`, *optional*):
|
||||||
|
Maximum sequence length for query state.
|
||||||
|
max_length_k (`int`, *optional*):
|
||||||
|
Maximum sequence length for key state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cu_seq_lens_q: Optional[torch.LongTensor]
|
||||||
|
cu_seq_lens_k: Optional[torch.LongTensor]
|
||||||
|
max_length_q: Optional[int]
|
||||||
|
max_length_k: Optional[int]
|
||||||
|
|||||||
@@ -33,12 +33,14 @@ from ...activations import ACT2FN
|
|||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
)
|
)
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
|
from ...processing_utils import Unpack
|
||||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
@@ -832,6 +834,7 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -913,6 +916,7 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
|
**flash_attn_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from ...modeling_outputs import (
|
|||||||
)
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
is_flash_attn_2_available,
|
is_flash_attn_2_available,
|
||||||
@@ -51,7 +52,11 @@ from .configuration_glm import GlmConfig
|
|||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
from ...modeling_flash_attention_utils import _flash_attention_forward
|
||||||
|
|
||||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
|
||||||
|
from ...processing_utils import Unpack
|
||||||
|
|
||||||
|
|
||||||
|
_CHECKPOINT_FOR_DOC = "dummy"
|
||||||
|
|
||||||
|
|
||||||
class GlmRMSNorm(nn.Module):
|
class GlmRMSNorm(nn.Module):
|
||||||
@@ -736,6 +741,7 @@ class GlmModel(GlmPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -817,6 +823,7 @@ class GlmModel(GlmPreTrainedModel):
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
|
**flash_attn_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
@@ -1222,6 +1229,11 @@ class GlmForTokenClassification(GlmPreTrainedModel):
|
|||||||
self.model.embed_tokens = value
|
self.model.embed_tokens = value
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=TokenClassifierOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ from .configuration_glm import GlmConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CHECKPOINT_FOR_DOC = "dummy"
|
||||||
|
|
||||||
|
|
||||||
class GlmRMSNorm(Phi3RMSNorm):
|
class GlmRMSNorm(Phi3RMSNorm):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from ...activations import ACT2FN
|
|||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_flash_attention_utils import _flash_attention_forward
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -39,8 +39,10 @@ from ...modeling_outputs import (
|
|||||||
)
|
)
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
|
from ...processing_utils import Unpack
|
||||||
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
|
LossKwargs,
|
||||||
add_code_sample_docstrings,
|
add_code_sample_docstrings,
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
@@ -422,6 +424,7 @@ class LlamaFlashAttention2(LlamaAttention):
|
|||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
||||||
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
if isinstance(past_key_value, StaticCache):
|
if isinstance(past_key_value, StaticCache):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -506,6 +509,7 @@ class LlamaFlashAttention2(LlamaAttention):
|
|||||||
sliding_window=getattr(self, "sliding_window", None),
|
sliding_window=getattr(self, "sliding_window", None),
|
||||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||||
is_causal=self.is_causal,
|
is_causal=self.is_causal,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
|
||||||
@@ -870,6 +874,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -951,6 +956,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
position_embeddings=position_embeddings,
|
position_embeddings=position_embeddings,
|
||||||
|
**flash_attn_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
@@ -1102,6 +1108,9 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
|
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||||
|
|
||||||
|
|
||||||
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
|
|
||||||
@@ -1148,7 +1157,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
|||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
num_logits_to_keep: int = 0,
|
num_logits_to_keep: int = 0,
|
||||||
**loss_kwargs,
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1198,6 +1207,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
@@ -1211,7 +1221,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
|
|||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[1:]
|
||||||
|
|||||||
@@ -815,7 +815,7 @@ class BatchEncoding(UserDict):
|
|||||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||||
# into a HalfTensor
|
# into a HalfTensor
|
||||||
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
|
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
|
||||||
self.data = {k: v.to(device=device) for k, v in self.data.items() if isinstance(v, torch.Tensor)}
|
self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
|
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from .doc import (
|
|||||||
from .generic import (
|
from .generic import (
|
||||||
ContextManagers,
|
ContextManagers,
|
||||||
ExplicitEnum,
|
ExplicitEnum,
|
||||||
|
LossKwargs,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
PaddingStrategy,
|
PaddingStrategy,
|
||||||
TensorType,
|
TensorType,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from contextlib import ExitStack, contextmanager
|
|||||||
from dataclasses import fields, is_dataclass
|
from dataclasses import fields, is_dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from typing import Any, ContextManager, Iterable, List, Optional, Tuple
|
from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from packaging import version
|
from packaging import version
|
||||||
@@ -854,3 +854,16 @@ def filter_out_non_signature_kwargs(extra: Optional[list] = None):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class LossKwargs(TypedDict, total=False):
|
||||||
|
"""
|
||||||
|
Keyword arguments to be passed to the loss function
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
num_items_in_batch (`int`, *optional*):
|
||||||
|
Number of items in the batch. It is recommended to pass it when
|
||||||
|
you are doing gradient accumulation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_items_in_batch: Optional[int]
|
||||||
|
|||||||
Reference in New Issue
Block a user