From 9cde2f5d420f23764acd8f6808f3e7d2836bbe87 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Tue, 20 May 2025 15:15:54 +0200 Subject: [PATCH] Minor llama4 fixes (#38123) * fix wrong scaling value/default Cache init * style * fix various issues on integration tests * change expected outputs * fixup * fix config access * protect default scaling --- .../models/llama4/modeling_llama4.py | 35 ++++++++++-- tests/models/llama4/test_modeling_llama4.py | 56 +++++++++---------- 2 files changed, 58 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index c4ef631e47..5bf5f1488c 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -258,6 +258,33 @@ def eager_attention_forward( return attn_output, attn_weights +# Adapted from transformers.models.llama.modeling_llama.eager_attention_forward -> llama4 doesn't cast attn weights to fp32 +def vision_eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.head_dim**-0.5 + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Llama4TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -534,10 +561,10 @@ class Llama4TextModel(Llama4PreTrainedModel): inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device)) if use_cache and past_key_values is None: - if self.config.get_text_config().get("attention_chunk_size") is not None: + if self.config.get_text_config().attention_chunk_size is not None: past_key_values = HybridChunkedCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1]) else: - past_key_values = DynamicCache(self.config, inputs_embeds.shape[0], inputs_embeds.shape[1]) + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -1099,7 +1126,7 @@ class Llama4VisionAttention(nn.Module): key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attention_interface: Callable = eager_attention_forward + attention_interface: Callable = vision_eager_attention_forward # flex disable because breaks on TP 8, embed is 88 not power of 2 if self.config._attn_implementation not in ["eager", "flex_attention"]: if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): @@ -1117,7 +1144,7 @@ class Llama4VisionAttention(nn.Module): value_states, None, dropout=0.0 if not self.training else self.attention_dropout, - scaling=None, + scaling=None, # TODO Might be enforced here for TP compatibility as scaling is not just sqrt(head_dim) is_causal=False, # HAS TO BE ENFORCED **kwargs, ) diff --git a/tests/models/llama4/test_modeling_llama4.py b/tests/models/llama4/test_modeling_llama4.py index f3f1d137d1..b349c47e3c 100644 --- a/tests/models/llama4/test_modeling_llama4.py +++ b/tests/models/llama4/test_modeling_llama4.py @@ -37,7 +37,7 @@ if is_torch_available(): @require_torch_large_gpu @require_read_token class Llama4IntegrationTest(unittest.TestCase): - model_id = "ll-re/Llama-4-17B-Omni-Instruct" + model_id = "meta-llama/Llama-4-Scout-17B-16E" # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) # Depending on the hardware we get different logits / generations cuda_compute_capability_major_version = None @@ -48,14 +48,17 @@ class Llama4IntegrationTest(unittest.TestCase): # 8 is for A100 / A10 and 7 for T4 cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] cls.model = Llama4ForConditionalGeneration.from_pretrained( - "ll-re/Llama-4-17B-Omni-Instruct", device_map="auto", torch_dtype=torch.float32 + "meta-llama/Llama-4-Scout-17B-16E", + device_map="auto", + torch_dtype=torch.float32, + attn_implementation="eager", ) def setUp(self): - self.processor = Llama4Processor.from_pretrained("ll-re/Llama-4-17B-Omni-Instruct", padding_side="left") + self.processor = Llama4Processor.from_pretrained("meta-llama/Llama-4-Scout-17B-16E", padding_side="left") url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" - self.messages = [ + self.messages_1 = [ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, { "role": "user", @@ -66,27 +69,7 @@ class Llama4IntegrationTest(unittest.TestCase): }, ] - def test_model_17b_16e_fp16(self): - EXPECTED_TEXT = [ - "The capital of France is Paris, which is located in the north-central part of the country. Paris is known for its iconic landmarks such as the", - "Roses are red, violets are blue, and this poem is about you. Roses are red, violets are blue, and I love", - ] - - messages = [ - {"role": "user", "content": "Who are you?"}, - ] - inputs = self.processor.apply_chat_template( - messages, add_generation_prompt=True, return_tensors="pt", return_dict=True - ).to(torch_device) - - output = self.model.generate(**inputs, max_new_tokens=100) - output_text = self.processor.batch_decode(output, skip_special_tokens=True) - - print(output_text) - self.assertEqual(output_text, EXPECTED_TEXT) - - def test_model_17b_16e_batch(self): - messages_2 = [ + self.messages_2 = [ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, { "role": "user", @@ -101,20 +84,35 @@ class Llama4IntegrationTest(unittest.TestCase): }, ] + def test_model_17b_16e_fp16(self): + EXPECTED_TEXT = [ + 'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white' + ] # fmt: skip + inputs = self.processor.apply_chat_template( - [self.messages, messages_2], + self.messages_1, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True + ).to(device=torch_device, dtype=self.model.dtype) + output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False) + output_text = self.processor.batch_decode(output, skip_special_tokens=True) + + print(output_text) + self.assertEqual(output_text, EXPECTED_TEXT) + + def test_model_17b_16e_batch(self): + inputs = self.processor.apply_chat_template( + [self.messages_1, self.messages_2], tokenize=True, return_dict=True, return_tensors="pt", padding=True, add_generation_prompt=True, - ).to(torch_device) + ).to(device=torch_device, dtype=torch.float32) output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False) output_text = self.processor.batch_decode(output, skip_special_tokens=True) EXPECTED_TEXTS = [ - 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like', - "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow" + 'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white', + 'system\n\nYou are a helpful assistant.user\n\nAre these images identical?assistant\n\nNo, these images are not identical. The first image shows a cow standing on a beach with a blue sky and a white cloud in the background.' ] # fmt: skip self.assertEqual(output_text, EXPECTED_TEXTS)