Remove graph breaks for torch.compile() in flash_attention_forward when Lllama Model is padding free tuned (#33932)

* fix: fixes for graph breaks

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: formatting

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: import error

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: Add Fa2Kwargs

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* Revert "PR changes"

This reverts commit 39d2868e5c93cc5f3f3c7c6ff981b66614c0e0e4.

* PR changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: FlashAttentionKwarg

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: FlashAttentionKwarg

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* addition of documentation

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* change in _flash_attention_forward

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* make fix-copies

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* revert make fix-copies

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix copies

* style

* loss kwargs typing

* style and pull latest changes

---------

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
This commit is contained in:
Abhishek Maurya
2024-10-24 05:02:54 -04:00
committed by GitHub
parent b0f0c61899
commit 65753d6065
9 changed files with 192 additions and 20 deletions

View File

@@ -348,6 +348,99 @@ model = AutoModelForCausalLM.from_pretrained(
) )
``` ```
### Fine-Tuning with torch.compile and Padding-Free Data Collation
In addition to optimizing inference, you can also enhance the training efficiency of large language models by leveraging torch.compile during fine-tuning and using a padding-free data collator. This approach can significantly speed up training and reduce computational overhead.
Here's how you can fine-tune a Llama model using SFTTrainer from the TRL library, with torch_compile enabled and a padding-free data collator:
```
#################### IMPORTS ###################
import math
import datasets
import dataclasses
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments
)
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
#################### MODEL LOADING WITH FLASH ATTENTION ###################
model_name = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="flash_attention_2" # Enables FlashAttention-2
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
#################### DATA PREPROCESSING (PADDING-FREE) ###################
response_template = "\n### Label:"
response_template_ids = tokenizer.encode(
response_template, add_special_tokens=False
)[2:] # Exclude special tokens
data_collator = DataCollatorForCompletionOnlyLM(
response_template_ids=response_template_ids,
tokenizer=tokenizer,
ignore_index=-100,
padding_free=True # Enables padding-free collation
)
def format_dataset(example):
return {
"output": example["output"] + tokenizer.eos_token
}
data_files = {"train": "path/to/dataset"} # Replace with your dataset path
json_dataset = datasets.load_dataset("json", data_files=data_files)
formatted_train_dataset = json_dataset["train"].map(format_dataset)
################# TRAINING CONFIGURATION ############################
train_args = TrainingArguments(
num_train_epochs=5,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=1e-5,
weight_decay=0.0,
warmup_ratio=0.03,
lr_scheduler_type="cosine",
logging_steps=1,
include_tokens_per_second=True,
save_strategy="epoch",
output_dir="output",
torch_compile=True, # Enables torch.compile
torch_compile_backend="inductor",
torch_compile_mode="default"
)
# Convert TrainingArguments to SFTConfig
transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
transformer_kwargs = {
k: v
for k, v in train_args.to_dict().items()
if k in transformer_train_arg_fields
}
training_args = SFTConfig(**transformer_kwargs)
####################### FINE-TUNING #####################
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=formatted_train_dataset,
data_collator=data_collator,
dataset_text_field="output",
args=training_args,
)
trainer.train()
```
### PyTorch scaled dot product attention ### PyTorch scaled dot product attention
Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation. Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.

View File

@@ -15,7 +15,7 @@
import inspect import inspect
import os import os
from typing import Optional, Tuple from typing import Optional, Tuple, TypedDict
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
flash_241 = is_flash_attn_greater_or_equal("2.4.1")
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
def _flash_attention_forward( def _flash_attention_forward(
query_states: torch.Tensor, query_states: torch.Tensor,
key_states: torch.Tensor, key_states: torch.Tensor,
@@ -194,6 +198,10 @@ def _flash_attention_forward(
use_top_left_mask: bool = False, use_top_left_mask: bool = False,
softcap: Optional[float] = None, softcap: Optional[float] = None,
deterministic: bool = None, deterministic: bool = None,
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
): ):
""" """
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -232,9 +240,9 @@ def _flash_attention_forward(
) )
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
if is_flash_attn_greater_or_equal("2.4.1"): if flash_241:
if deterministic is None: if deterministic is None:
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" deterministic = deterministic_g
flash_kwargs["deterministic"] = deterministic flash_kwargs["deterministic"] = deterministic
if softcap is not None: if softcap is not None:
@@ -267,24 +275,32 @@ def _flash_attention_forward(
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
# Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always) elif position_ids is not None and (
elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all(): max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
batch_size = query_states.size(0) batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens if cu_seq_lens_q is None or cu_seq_lens_k is None:
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens
else:
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
attn_output = flash_attn_varlen_func( attn_output = flash_attn_varlen_func(
query_states, query_states,
key_states, key_states,
value_states, value_states,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seqlens_k, cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_q=max_length_q,
max_seqlen_k=max_seqlen_in_batch_k, max_seqlen_k=max_length_k,
dropout_p=dropout, dropout_p=dropout,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
causal=causal, causal=causal,
@@ -299,3 +315,24 @@ def _flash_attention_forward(
) )
return attn_output return attn_output
class FlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for Flash Attention with Compile.
Attributes:
cu_seq_lens_q (`torch.LongTensor`, *optional*)
Gets cumlative sequence length for query state.
cu_seq_lens_k (`torch.LongTensor`, *optional*)
Gets cumlative sequence length for key state.
max_length_q (`int`, *optional*):
Maximum sequence length for query state.
max_length_k (`int`, *optional*):
Maximum sequence length for key state.
"""
cu_seq_lens_q: Optional[torch.LongTensor]
cu_seq_lens_k: Optional[torch.LongTensor]
max_length_q: Optional[int]
max_length_k: Optional[int]

View File

@@ -33,12 +33,14 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast, CausalLMOutputWithPast,
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import ( from ...utils import (
add_start_docstrings, add_start_docstrings,
@@ -832,6 +834,7 @@ class CohereModel(CoherePreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -913,6 +916,7 @@ class CohereModel(CoherePreTrainedModel):
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
**flash_attn_kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]

View File

@@ -38,6 +38,7 @@ from ...modeling_outputs import (
) )
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import ( from ...utils import (
add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_flash_attn_2_available, is_flash_attn_2_available,
@@ -51,7 +52,11 @@ from .configuration_glm import GlmConfig
if is_flash_attn_2_available(): if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...processing_utils import Unpack
_CHECKPOINT_FOR_DOC = "dummy"
class GlmRMSNorm(nn.Module): class GlmRMSNorm(nn.Module):
@@ -736,6 +741,7 @@ class GlmModel(GlmPreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -817,6 +823,7 @@ class GlmModel(GlmPreTrainedModel):
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
**flash_attn_kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
@@ -1222,6 +1229,11 @@ class GlmForTokenClassification(GlmPreTrainedModel):
self.model.embed_tokens = value self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,

View File

@@ -46,6 +46,8 @@ from .configuration_glm import GlmConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "dummy"
class GlmRMSNorm(Phi3RMSNorm): class GlmRMSNorm(Phi3RMSNorm):
pass pass

View File

@@ -29,7 +29,7 @@ from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast, CausalLMOutputWithPast,
@@ -39,8 +39,10 @@ from ...modeling_outputs import (
) )
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import ( from ...utils import (
LossKwargs,
add_code_sample_docstrings, add_code_sample_docstrings,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
@@ -422,6 +424,7 @@ class LlamaFlashAttention2(LlamaAttention):
use_cache: bool = False, use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache): if isinstance(past_key_value, StaticCache):
raise ValueError( raise ValueError(
@@ -506,6 +509,7 @@ class LlamaFlashAttention2(LlamaAttention):
sliding_window=getattr(self, "sliding_window", None), sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask, use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal, is_causal=self.is_causal,
**kwargs,
) )
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -870,6 +874,7 @@ class LlamaModel(LlamaPreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
@@ -951,6 +956,7 @@ class LlamaModel(LlamaPreTrainedModel):
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
position_embeddings=position_embeddings, position_embeddings=position_embeddings,
**flash_attn_kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
@@ -1102,6 +1108,9 @@ class LlamaModel(LlamaPreTrainedModel):
return causal_mask return causal_mask
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
@@ -1148,7 +1157,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0, num_logits_to_keep: int = 0,
**loss_kwargs, **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[Tuple, CausalLMOutputWithPast]:
r""" r"""
Args: Args:
@@ -1198,6 +1207,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position, cache_position=cache_position,
**kwargs,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
@@ -1211,7 +1221,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
loss = None loss = None
if labels is not None: if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs) loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]

View File

@@ -815,7 +815,7 @@ class BatchEncoding(UserDict):
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor # into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int): if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
self.data = {k: v.to(device=device) for k, v in self.data.items() if isinstance(v, torch.Tensor)} self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
else: else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.") logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
return self return self

View File

@@ -37,6 +37,7 @@ from .doc import (
from .generic import ( from .generic import (
ContextManagers, ContextManagers,
ExplicitEnum, ExplicitEnum,
LossKwargs,
ModelOutput, ModelOutput,
PaddingStrategy, PaddingStrategy,
TensorType, TensorType,

View File

@@ -24,7 +24,7 @@ from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from enum import Enum from enum import Enum
from functools import partial, wraps from functools import partial, wraps
from typing import Any, ContextManager, Iterable, List, Optional, Tuple from typing import Any, ContextManager, Iterable, List, Optional, Tuple, TypedDict
import numpy as np import numpy as np
from packaging import version from packaging import version
@@ -854,3 +854,16 @@ def filter_out_non_signature_kwargs(extra: Optional[list] = None):
return wrapper return wrapper
return decorator return decorator
class LossKwargs(TypedDict, total=False):
"""
Keyword arguments to be passed to the loss function
Attributes:
num_items_in_batch (`int`, *optional*):
Number of items in the batch. It is recommended to pass it when
you are doing gradient accumulation.
"""
num_items_in_batch: Optional[int]