Fix: Jamba batched generation (#32914)
* init fix * fix mask during cached forward, move mask related stuff to own function * adjust tests as left padding does not change logits as much anymore + batch gen (with todo on logits comp) * revert overwriting new integration tests * move some comments to docstring
This commit is contained in:
@@ -458,51 +458,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
def test_left_padding_compatibility(self):
|
||||
r"""
|
||||
Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
|
||||
effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value.
|
||||
"""
|
||||
import inspect
|
||||
# NOTE: left-padding results in small numerical differences. This is expected.
|
||||
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
|
||||
|
||||
# First, filter out models that don't support left padding - generative and decoder-only.
|
||||
# Jamba is a decoder-only architecture
|
||||
decoder_only_classes = self.all_generative_model_classes
|
||||
|
||||
# Then, test left-padding
|
||||
def _prepare_model_kwargs(input_ids, attention_mask, signature):
|
||||
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
if "position_ids" in signature:
|
||||
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
model_kwargs["position_ids"] = position_ids
|
||||
if "cache_position" in signature:
|
||||
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
|
||||
model_kwargs["cache_position"] = cache_position
|
||||
return model_kwargs
|
||||
|
||||
for model_class in decoder_only_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
signature = inspect.signature(model.forward).parameters.keys()
|
||||
|
||||
# Without padding
|
||||
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
|
||||
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
|
||||
|
||||
# With left-padding (length 32)
|
||||
pad_size = (input_ids.shape[0], 32)
|
||||
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
|
||||
padded_input_ids = torch.cat((padding, input_ids), dim=1)
|
||||
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
|
||||
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
|
||||
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
|
||||
|
||||
# They should result in very similar logits
|
||||
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
@@ -692,7 +647,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
|
||||
[
|
||||
0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660,
|
||||
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
|
||||
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
|
||||
0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292,
|
||||
0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906,
|
||||
0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562
|
||||
@@ -737,10 +692,11 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
logits = self.model(input_ids=inputs["input_ids"]).logits
|
||||
|
||||
# TODO fix logits
|
||||
EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
|
||||
[
|
||||
0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641,
|
||||
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
|
||||
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
|
||||
0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261,
|
||||
0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945,
|
||||
0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583
|
||||
@@ -749,7 +705,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
|
||||
[
|
||||
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
|
||||
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
|
||||
0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852,
|
||||
0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129,
|
||||
0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891,
|
||||
|
||||
Reference in New Issue
Block a user