Add padding-free to Granite hybrid moe models (#39677)

* start fixing kwarg handling

* fmt

* updates padding free tests

* docs

* add missing kwargs modeling_granitemoe.py

* run modular util

* rm unrelated changes from modular util
This commit is contained in:
Garrett Goon
2025-07-25 14:10:50 -04:00
committed by GitHub
parent d6e9f71a6e
commit 97f8c71f52
7 changed files with 146 additions and 16 deletions

View File

@@ -48,6 +48,32 @@ for i in output:
This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co/SukritiSharma) and [Alexander Brooks](https://huggingface.co/abrooks9944).
## Notes
- `GraniteMoeHybridForCausalLM` supports padding-free training which concatenates distinct training examples while still processing inputs as separate batches. It can significantly accelerate inference by [~2x](https://github.com/huggingface/transformers/pull/35861#issue-2807873129) (depending on model and data distribution) and reduce memory-usage if there are examples of varying lengths by avoiding unnecessary compute and memory overhead from padding tokens.
Padding-free training requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d` packages and the following arguments must be passed to the model in addition to `input_ids` and `labels`.
- `position_ids: torch.LongTensor`: the position index of each token in each sequence.
- `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
- Each of the [`FlashAttentionKwargs`]
- `cu_seq_lens_q: torch.LongTensor`: the cumulative sequence lengths of all queries.
- `cu_seq_lens_k: torch.LongTensor`: the cumulative sequence lengths of all keys.
- `max_length_q: int`: the longest query length in the batch.
- `max_length_k: int`: the longest key length in the batch.
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] programmatically generates the set of additional arguments above using `return_seq_idx=True` and `return_flash_attn_kwargs=True`. See the [Improving Hugging Face Training Efficiency Through Packing with Flash Attention](https://huggingface.co/blog/packing-with-FA2) blog post for additional information.
```python
from transformers import DataCollatorWithFlattening
# Example of using padding-free training
data_collator = DataCollatorWithFlattening(
tokenizer=tokenizer,
return_seq_idx=True,
return_flash_attn_kwargs=True
)
```
## GraniteMoeHybridConfig
@@ -61,4 +87,4 @@ This HF implementation is contributed by [Sukriti Sharma](https://huggingface.co
## GraniteMoeHybridForCausalLM
[[autodoc]] GraniteMoeHybridForCausalLM
- forward
- forward