Add padding-free support to Granite hybrid MoE models (#39677)

* start fixing kwarg handling

* fmt

* update padding-free tests

* docs

* add missing kwargs to modeling_granitemoe.py

* run modular util

* rm unrelated changes from modular util
This commit is contained in:
Garrett Goon
2025-07-25 14:10:50 -04:00
committed by GitHub
parent d6e9f71a6e
commit 97f8c71f52
7 changed files with 146 additions and 16 deletions

View File

@@ -551,6 +551,15 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
# Ensure inputs_dict also has labels in it, as their presence/absence can induce
# dtype conversions. This also lets us compare losses.
labels = inputs_dict["input_ids"].clone()
# Mask padding tokens
labels[~dummy_attention_mask.bool()] = -100
# Also mask the first unmasked (label >= 0) token of each row to match the padding-free batch.
first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100
inputs_dict["labels"] = labels
model = (
model_class.from_pretrained(
@@ -586,6 +595,10 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
loss_padded = res_padded.loss
loss_padfree = res_padfree.loss
torch.testing.assert_close(loss_padded, loss_padfree)
@slow
@require_torch