From 8bcf9c8dd4419b6b0216fa68c2a5b2b156cd64de Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 7 Jun 2024 11:51:41 +0200 Subject: [PATCH] Fix jetmoe model (#31279) * Fix jetmoe model * Remove skip-tests --- .../models/jetmoe/modeling_jetmoe.py | 22 ++++++++----------- tests/models/jetmoe/test_modeling_jetmoe.py | 8 ------- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index bd25c77fda..3ae381880b 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1404,18 +1404,14 @@ class JetMoeForCausalLM(JetMoePreTrainedModel): past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1446,7 +1442,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel): position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py index c61ff01be2..12e5dd682c 100644 --- a/tests/models/jetmoe/test_modeling_jetmoe.py +++ b/tests/models/jetmoe/test_modeling_jetmoe.py @@ -472,14 +472,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("JetMoe flash attention does not support right padding") - @unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ") - def test_beam_sample_generate(self): - pass - - @unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ") - def test_generate_from_inputs_embeds_decoder_only(self): - pass - @require_torch class JetMoeIntegrationTest(unittest.TestCase):