Add padding-free support to Granite hybrid MoE models (#39677)
* start fixing kwarg handling
* fmt
* update padding-free tests
* docs
* add missing kwargs to modeling_granitemoe.py
* run modular util
* rm unrelated changes from modular util
This commit is contained in:
@@ -551,6 +551,15 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
# Ensure inputs_dict also has labels in it, as their presence/absence can induce
|
||||
# dtype conversions. This also lets us compare losses.
|
||||
labels = inputs_dict["input_ids"].clone()
|
||||
# Mask padding tokens
|
||||
labels[~dummy_attention_mask.bool()] = -100
|
||||
# Also need to mask the first non-trivial token to match the padding-free batch.
|
||||
first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
|
||||
labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100
|
||||
inputs_dict["labels"] = labels
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
@@ -586,6 +595,10 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
tol = torch.finfo(torch.float16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
loss_padded = res_padded.loss
|
||||
loss_padfree = res_padfree.loss
|
||||
torch.testing.assert_close(loss_padded, loss_padfree)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user