diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index ed306fac94..6c4b7498b3 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -507,6 +507,8 @@
title: Llama2
- local: model_doc/llama3
title: Llama3
+ - local: model_doc/llama4
+ title: Llama4
- local: model_doc/longformer
title: Longformer
- local: model_doc/longt5
diff --git a/docs/source/en/model_doc/llama4.md b/docs/source/en/model_doc/llama4.md
new file mode 100644
index 0000000000..8e2cd3a278
--- /dev/null
+++ b/docs/source/en/model_doc/llama4.md
@@ -0,0 +1,442 @@
+
+
+# Llama4
+
+
+
+
+

+

+
+
+
+Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture.
+This generation includes two models:
+- The highly capable Llama 4 Maverick with 17B active parameters out of ~400B total, with 128 experts.
+- The efficient Llama 4 Scout also has 17B active parameters out of ~109B total, using just 16 experts.
+
+Both models leverage early fusion for native multimodality, enabling them to process text and image inputs.
+Maverick and Scout are both trained on up to 40 trillion tokens on data encompassing 200 languages
+(with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).
+
+For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via
+on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats.
+These models are released under the custom Llama 4 Community License Agreement, available on the model repositories.
+
+You can find all the original Llama checkpoints under the [meta-llama](https://huggingface.co/meta-llama) organization.
+
+> [!TIP]
+> The Llama 4 family of models comes in two flavors: 109B, and 402B parameters. Both of these flavors are extremely
+> large and won't fit on your run-of-the-mill device. See below for some examples to reduce the memory usage of the
+> model.
+>
+> For the download to be faster and more resilient, we recommend installing the `hf_xet` dependency as follows:
+> `pip install transformers[hf_xet]`
+
+The examples below demonstrate how to generate with [`Pipeline`] or the [`AutoModel`]. We additionally add an example
+showcasing how to toggle the right attributes to enable very long-context generations, as some flavors of Llama 4
+have context lengths going up to 10 million tokens.
+
+
+
+
+
+```py
+from transformers import pipeline
+import torch
+
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+messages = [
+ {"role": "user", "content": "what is the recipe of mayonnaise?"},
+]
+
+pipe = pipeline(
+ "text-generation",
+ model=model_id,
+ device_map="auto",
+ torch_dtype=torch.bfloat16
+)
+
+output = pipe(messages, do_sample=False, max_new_tokens=200)
+print(output[0]["generated_text"][-1]["content"])
+```
+
+
+
+
+```py
+from transformers import AutoTokenizer, Llama4ForConditionalGeneration
+import torch
+
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+messages = [
+ {"role": "user", "content": "Who are you?"},
+]
+inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
+
+model = Llama4ForConditionalGeneration.from_pretrained(
+ model_id,
+ device_map="auto",
+ torch_dtype=torch.bfloat16
+)
+
+outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
+outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
+print(outputs[0])
+```
+
+
+
+
+```py
+from transformers import AutoProcessor, Llama4ForConditionalGeneration
+import torch
+
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+processor = AutoProcessor.from_pretrained(model_id)
+model = Llama4ForConditionalGeneration.from_pretrained(
+ model_id,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+)
+
+img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
+messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": img_url},
+ {"type": "text", "text": "Describe this image in two sentences."},
+ ]
+ },
+]
+
+inputs = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+).to(model.device)
+
+outputs = model.generate(
+ **inputs,
+ max_new_tokens=256,
+)
+
+response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
+print(response)
+```
+
+
+
+
+```py
+from transformers import AutoProcessor, Llama4ForConditionalGeneration
+import torch
+
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+processor = AutoProcessor.from_pretrained(model_id)
+model = Llama4ForConditionalGeneration.from_pretrained(
+ model_id,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+)
+
+url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
+url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
+messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": url1},
+ {"type": "image", "url": url2},
+ {"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
+ ]
+ },
+]
+
+inputs = processor.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+).to(model.device)
+
+outputs = model.generate(
+ **inputs,
+ max_new_tokens=256,
+)
+
+response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
+print(response)
+```
+
+
+
+
+Beware: the example below uses both `device_map="auto"` and flex-attention.
+Please use `torchrun` to run this example in tensor-parallel mode.
+
+We will work to enable running with `device_map="auto"` and flex-attention without
+tensor-parallel in the future.
+
+```py
+from transformers import Llama4ForConditionalGeneration, AutoTokenizer
+import torch
+import time
+
+file = "very_long_context_prompt.txt"
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+with open(file, "r") as f:
+ very_long_text = "\n".join(f.readlines())
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = Llama4ForConditionalGeneration.from_pretrained(
+ model_id,
+ device_map="auto",
+ attn_implementation="flex_attention",
+ torch_dtype=torch.bfloat16
+)
+
+messages = [
+ {"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."},
+]
+input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+
+torch.cuda.synchronize()
+start = time.time()
+out = model.generate(
+ input_ids.to(model.device),
+ prefill_chunk_size=2048*8,
+ max_new_tokens=300,
+ cache_implementation="hybrid",
+)
+print(time.time()-start)
+print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:]))
+print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB")
+```
+
+
+
+
+## Efficiency; how to get the best out of llama 4
+
+### The Attention methods
+
+Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the [Attention Interface](../attention_interface) overview for an in-depth explanation of our interface.
+
+As of release, the Llama 4 model supports the following attention methods: `eager`, `flex_attention`, `sdpa`. We recommend using `flex_attention` for best results.
+Switching attention mechanism is done at the model initialization step:
+
+
+
+
+
+Setting Flex Attention ensures the best results with the very long context the model can handle.
+
+> [!TIP]
+> Beware: the example below uses both `device_map="auto"` and flex-attention.
+> Please use `torchrun` to run this example in tensor-parallel mode.
+>
+> We will work to enable running with `device_map="auto"` and flex-attention without
+> tensor-parallel in the future.
+
+```py
+from transformers import Llama4ForConditionalGeneration
+import torch
+
+model = Llama4ForConditionalGeneration.from_pretrained(
+ model_id,
+ attn_implementation="flex_attention",
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+)
+```
+
+
+The `sdpa` attention method is generally more compute-efficient than the `eager` method.
+
+```py
+from transformers import Llama4ForConditionalGeneration
+import torch
+
+model = Llama4ForConditionalGeneration.from_pretrained(
+ model_id,
+ attn_implementation="sdpa",
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+)
+```
+
+
+The `eager` attention method is set by default, so no need for anything different when loading the model:
+
+```py
+from transformers import Llama4ForConditionalGeneration
+import torch
+
+model = Llama4ForConditionalGeneration.from_pretrained(
+ model_id,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+)
+```
+
+
+
+
+### Quantization
+
+Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for available quantization backends.
+At time of release, both FBGEMM and LLM-Compressor are supported; more quantization methods will be supported in the days that follow the release.
+
+See below for examples using both:
+
+
+
+Here is an example loading a BF16 model in FP8 using the FBGEMM approach:
+
+
+
+
+```python
+from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config
+import torch
+
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+messages = [
+ {"role": "user", "content": "Who are you?"},
+]
+inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
+
+model = Llama4ForConditionalGeneration.from_pretrained(
+ model_id,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+ quantization_config=FbgemmFp8Config()
+)
+
+outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
+outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
+print(outputs[0])
+```
+
+
+
+
+To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release:
+
+```python
+from transformers import AutoTokenizer, Llama4ForConditionalGeneration
+import torch
+
+model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+messages = [
+ {"role": "user", "content": "Who are you?"},
+]
+inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
+
+model = Llama4ForConditionalGeneration.from_pretrained(
+ model_id,
+ tp_plan="auto",
+ torch_dtype=torch.bfloat16,
+)
+
+outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
+outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
+print(outputs[0])
+```
+
+
+
+### Offloading
+
+Enabling CPU-offloading means that components of the model might be moved to CPU instead of GPU in case the GPU-memory available isn't sufficient to load the entire model.
+At inference, different components will be loaded/unloaded from/to the GPU on the fly. This ensures that the model can be loaded on smaller machines as long as the CPU-memory is sufficient.
+However, this also slows down inference as it adds communication overhead.
+
+In order to enable CPU-offloading, you simply need to set `device_map` to `auto` when loading the model:
+
+```py
+from transformers import Llama4ForConditionalGeneration
+import torch
+
+model = Llama4ForConditionalGeneration.from_pretrained(
+ model_id,
+ device_map="auto",
+ torch_dtype=torch.bfloat16,
+)
+```
+
+## Llama4Config
+
+[[autodoc]] Llama4Config
+
+## Llama4TextConfig
+
+[[autodoc]] Llama4TextConfig
+
+## Llama4VisionConfig
+
+[[autodoc]] Llama4VisionConfig
+
+## Llama4Processor
+
+[[autodoc]] Llama4Processor
+
+## Llama4ImageProcessorFast
+
+[[autodoc]] Llama4ImageProcessorFast
+
+## Llama4ForConditionalGeneration
+
+[[autodoc]] Llama4ForConditionalGeneration
+- forward
+
+## Llama4ForCausalLM
+
+[[autodoc]] Llama4ForCausalLM
+- forward
+
+## Llama4TextModel
+
+[[autodoc]] Llama4TextModel
+- forward
+
+## Llama4ForCausalLM
+
+[[autodoc]] Llama4ForCausalLM
+- forward
+
+## Llama4VisionModel
+
+[[autodoc]] Llama4VisionModel
+- forward
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 96f993e7e5..3b525f7f15 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -562,6 +562,12 @@ _import_structure = {
"models.levit": ["LevitConfig"],
"models.lilt": ["LiltConfig"],
"models.llama": ["LlamaConfig"],
+ "models.llama4": [
+ "Llama4Config",
+ "Llama4Processor",
+ "Llama4TextConfig",
+ "Llama4VisionConfig",
+ ],
"models.llava": [
"LlavaConfig",
"LlavaProcessor",
@@ -1354,6 +1360,7 @@ else:
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.gemma3"].append("Gemma3ImageProcessorFast")
_import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
+ _import_structure["models.llama4"].append("Llama4ImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
@@ -2510,6 +2517,15 @@ else:
"GlmPreTrainedModel",
]
)
+ _import_structure["models.llama4"].extend(
+ [
+ "Llama4ForCausalLM",
+ "Llama4ForConditionalGeneration",
+ "Llama4TextModel",
+ "Llama4VisionModel",
+ "Llama4PreTrainedModel",
+ ]
+ )
_import_structure["models.glpn"].extend(
[
"GLPNForDepthEstimation",
@@ -5807,6 +5823,12 @@ if TYPE_CHECKING:
from .models.levit import LevitConfig
from .models.lilt import LiltConfig
from .models.llama import LlamaConfig
+ from .models.llama4 import (
+ Llama4Config,
+ Llama4Processor,
+ Llama4TextConfig,
+ Llama4VisionConfig,
+ )
from .models.llava import (
LlavaConfig,
LlavaProcessor,
@@ -6646,6 +6668,7 @@ if TYPE_CHECKING:
from .models.detr import DetrImageProcessorFast
from .models.gemma3 import Gemma3ImageProcessorFast
from .models.got_ocr2 import GotOcr2ImageProcessorFast
+ from .models.llama4 import Llama4ImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
@@ -7827,6 +7850,13 @@ if TYPE_CHECKING:
LlamaModel,
LlamaPreTrainedModel,
)
+ from .models.llama4 import (
+ Llama4ForCausalLM,
+ Llama4ForConditionalGeneration,
+ Llama4PreTrainedModel,
+ Llama4TextModel,
+ Llama4VisionModel,
+ )
from .models.llava import (
LlavaForConditionalGeneration,
LlavaPreTrainedModel,
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 3f2e02a703..cd2a711960 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1811,6 +1811,200 @@ class HybridCache(Cache):
self.value_cache[layer_idx].zero_()
+class HybridChunkedCache(Cache):
+    """
+    Hybrid cache for models that interleave chunked (sliding-window) attention layers with global attention
+    layers, e.g. Llama 4. Layers flagged in `is_sliding` keep a rolling buffer holding the last `sliding_window`
+    key/value states, while the remaining layers use a static buffer of `max_cache_len` positions.
+
+    The per-layer buffers are allocated lazily, the first time `update` is called for a layer, on the device of
+    the key states received in that call.
+
+    Parameters:
+        config (`PretrainedConfig`):
+            The configuration file defining the shape-related attributes required to initialize the cache. The
+            attributes are read from `config.get_text_config()`.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used.
+        max_cache_len (`int`, *optional*):
+            The maximum sequence length with which the model will be used. It sizes the buffers of the global
+            attention layers.
+        device (`torch.device` or `str`, *optional*):
+            Accepted for signature compatibility; currently unused, since each layer buffer is created on the
+            device of the key states it first receives.
+        dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`):
+            The default `dtype` to use when initializing the layer.
+        layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*):
+            Accepted for signature compatibility; currently unused for the same reason as `device`.
+
+    Example:
+
+    ```python
+    from transformers.cache_utils import HybridChunkedCache
+
+    # `model` is an already loaded model with chunked attention layers and `inputs` its tokenized prompt.
+    # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+    max_generated_length = inputs.input_ids.shape[1] + 10
+    past_key_values = HybridChunkedCache(
+        config=model.config, max_batch_size=1, max_cache_len=max_generated_length, dtype=model.dtype
+    )
+    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+    outputs.past_key_values  # access cache filled with key/values from generation
+    ```
+    """
+
+    # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
+    # ALL changes from the PR that commented the line below when reactivating it.
+    # is_compileable = True
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_batch_size: int,
+        max_cache_len: Optional[int] = None,
+        device: Union[torch.device, str, None] = None,
+        dtype: torch.dtype = torch.bfloat16,
+        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
+    ) -> None:
+        super().__init__()
+        # Read every decoder attribute from the same object. Mixing `config` and `config.get_text_config()`
+        # breaks when a composite config is passed (presumably `get_text_config()` returns the config itself for
+        # text-only configs -- TODO confirm).
+        text_config = config.get_text_config()
+
+        sliding_window = getattr(text_config, "sliding_window", None)
+        if sliding_window is None:
+            # Chunked-attention models expose the window size as `attention_chunk_size`
+            sliding_window = getattr(text_config, "attention_chunk_size", 8192)
+        self.sliding_window = sliding_window
+        self.max_cache_len = max_cache_len
+        self.max_batch_size = max_batch_size
+
+        # Only fall back to `hidden_size // num_attention_heads` when `head_dim` is not provided, so that configs
+        # that define `head_dim` do not also need the other two attributes.
+        head_dim = getattr(text_config, "head_dim", None)
+        if head_dim is None:
+            head_dim = text_config.hidden_size // text_config.num_attention_heads
+        self.head_dim = head_dim
+        self._dtype = dtype
+
+        if hasattr(text_config, "no_rope_layers"):
+            # A layer is treated as sliding when its `no_rope_layers` entry is truthy
+            self.is_sliding = text_config.no_rope_layers
+        else:
+            layer_switch = getattr(text_config, "sliding_window_pattern", 2)
+            self.is_sliding = [bool((i + 1) % layer_switch) for i in range(text_config.num_hidden_layers)]
+
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        # Number of tokens seen so far by each layer (only maintained by the sliding update)
+        self.cumulative_length = [0] * text_config.num_hidden_layers
+
+    def initialise_cache_layer(self, layer_idx, key_states):
+        """
+        Allocates the key/value buffers of layer `layer_idx` if they do not exist yet. The number of key/value
+        heads and the device are taken from `key_states`.
+        """
+        # Buffers are appended, so layers are expected to be initialised in increasing `layer_idx` order
+        if len(self.key_cache) > layer_idx:
+            return
+
+        num_key_value_heads = key_states.shape[1]
+        device = key_states.device
+        # Sliding layers only need room for one window, global layers for the full sequence
+        cache_len = self.sliding_window if self.is_sliding[layer_idx] else self.max_cache_len
+        cache_shape = (self.max_batch_size, num_key_value_heads, cache_len, self.head_dim)
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
+        # breaks when updating the cache.
+        torch._dynamo.mark_static_address(new_layer_key_cache)
+        torch._dynamo.mark_static_address(new_layer_value_cache)
+        self.key_cache.append(new_layer_key_cache)
+        self.value_cache.append(new_layer_value_cache)
+
+    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
+        """
+        Updates a sliding layer. While the new states still fit in the window they are written in place at
+        `cache_position` and the buffers are returned. Otherwise the buffer keeps the last `max_cache_len`
+        states, and the full concatenation of the cached and new states is returned.
+        """
+        cumulative_length = self.cumulative_length[layer_idx]
+        num_new_tokens = key_states.shape[-2]
+
+        if cumulative_length >= max_cache_len:
+            # Window already full: drop the oldest slot and append the new states
+            full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2)
+            full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2)
+        elif cumulative_length + num_new_tokens > max_cache_len:
+            # Window overflows with this update: keep only the slots filled so far and append the new states
+            full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2)
+            full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2)
+        else:
+            # Everything fits: in-place write at the requested positions
+            self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
+            self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
+            self.cumulative_length[layer_idx] += num_new_tokens
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :])
+        self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :])
+        self.cumulative_length[layer_idx] += num_new_tokens
+        # we should return the whole states instead of k_out, v_out to take the whole prompt
+        # into consideration when building kv cache instead of just throwing away tokens outside of the window
+        return full_key_states, full_value_states
+
+    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
+        """Updates a global layer by writing the new states at `cache_position` in the static buffers."""
+        k_out[:, :, cache_position] = key_states
+        v_out[:, :, cache_position] = value_states
+
+        self.key_cache[layer_idx] = k_out
+        self.value_cache[layer_idx] = v_out
+        return k_out, v_out
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Stores `key_states` and `value_states` for layer `layer_idx` and returns the key/value states to attend
+        to. The positions to write are read from `cache_kwargs["cache_position"]`.
+        """
+        if cache_kwargs is None:
+            cache_kwargs = {}
+        cache_position = cache_kwargs.get("cache_position")
+        self.initialise_cache_layer(layer_idx, key_states)
+
+        # Buffers live on the device of the first states seen for the layer; move them if the incoming states
+        # are on another device.
+        if self.key_cache[layer_idx].device != key_states.device:
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
+        if self.value_cache[layer_idx].device != value_states.device:
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
+
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]
+        key_states = key_states.to(k_out.dtype)
+        value_states = value_states.to(v_out.dtype)
+
+        update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update
+        # The buffer length (dim 2) is the window size for sliding layers and `max_cache_len` for global ones
+        return update_fn(
+            cache_position,
+            layer_idx,
+            key_states,
+            value_states,
+            k_out,
+            v_out,
+            k_out.shape[2],
+        )
+
+    def get_max_cache_shape(self) -> Optional[int]:
+        """Returns the maximum sequence length of the global attention layers."""
+        return self.max_cache_len
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0):
+        """
+        Returns the number of occupied positions in the buffer of layer 0 (0 if no layer has been initialised).
+        Only `layer_idx=0` is supported.
+        """
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+        # limit the check to the first batch member and head dimension.
+        # TODO: deprecate this function in favor of `cache_position`
+        if layer_idx != 0:
+            raise ValueError(
+                "`get_seq_length` on `HybridChunkedCache` may get inconsistent results depending on the layer index. "
+                "Using the `layer_idx` argument is not supported."
+            )
+        if len(self.key_cache) == 0:
+            return 0
+        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
+
+    def reset(self):
+        """Resets the cache values while preserving the objects"""
+        for layer_idx in range(len(self.key_cache)):
+            # In-place ops prevent breaking the static address
+            self.key_cache[layer_idx].zero_()
+            self.value_cache[layer_idx].zero_()
+        self.cumulative_length = [0] * len(self.cumulative_length)
+
+
class MambaCache:
"""
Cache for mamba model which does not have attention mechanism and key value states.
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 808cfbcf93..d5991cae8f 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -801,18 +801,19 @@ class PretrainedConfig(PushToHubMixin):
def to_diff_dict(self) -> dict[str, Any]:
"""
- Removes all attributes from config which correspond to the default config attributes for better readability and
- serializes to a Python dictionary.
+ Removes all attributes from the configuration that correspond to the default config attributes for
+ better readability, while always retaining the `config` attribute from the class. Serializes to a
+ Python dictionary.
Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
+ Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
"""
config_dict = self.to_dict()
- # get the default config dict
+ # Get the default config dict (from a fresh PreTrainedConfig instance)
default_config_dict = PretrainedConfig().to_dict()
- # get class specific config dict
+ # Get class-specific config dict if not part of a composition
class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
serializable_config_dict = {}
@@ -847,8 +848,7 @@ class PretrainedConfig(PushToHubMixin):
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
-
- # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
+ # Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = serializable_config_dict.pop("_pre_quantization_dtype", None)
self.dict_torch_dtype_to_str(serializable_config_dict)
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index a6b0a72162..c743480d78 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -52,6 +52,7 @@ if is_torch_available():
from ..cache_utils import (
HQQQuantizedCache,
HybridCache,
+ HybridChunkedCache,
MambaCache,
OffloadedStaticCache,
QuantizedCacheConfig,
@@ -69,6 +70,7 @@ if is_torch_available():
"offloaded_static": OffloadedStaticCache,
"sliding_window": SlidingWindowCache,
"hybrid": HybridCache,
+ "hybrid_chunked": HybridChunkedCache,
"mamba": MambaCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
@@ -416,6 +418,7 @@ class GenerationConfig(PushToHubMixin):
if isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
+ self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)
# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 232cceeedf..bc00e29ba5 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1830,6 +1830,9 @@ class GenerationMixin:
Returns the resulting cache object.
"""
+ if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
+ cache_implementation = "hybrid_chunked"
+
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
@@ -3405,7 +3408,12 @@ class GenerationMixin:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)
- is_prefill = True
+ if generation_config.prefill_chunk_size is not None:
+ model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
+ is_prefill = False
+ else:
+ is_prefill = True
+
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -4855,6 +4863,45 @@ class GenerationMixin:
else:
return input_ids
+    def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
+        """
+        Prefills the cache by running the prompt through the model in chunks of
+        `generation_config.prefill_chunk_size` tokens.
+
+        Every prompt token except the last one is processed here, so that decoding is completely performed
+        outside this function, starting from the last prompt token with an already populated cache.
+
+        Args:
+            input_ids (`torch.LongTensor`):
+                The prompt token ids; chunks are taken along the last dimension.
+            generation_config (`GenerationConfig`):
+                Provides `prefill_chunk_size` and `compile_config`.
+            model_kwargs:
+                Model inputs. Must contain `past_key_values`; `attention_mask` is optional.
+
+        Returns:
+            `model_kwargs` with the prefilled `past_key_values`, the original `attention_mask` (when one was
+            given), `cache_position` pointing at the last prompt token, and without `position_ids`.
+
+        Raises:
+            ValueError: If `model_kwargs` holds no `past_key_values`.
+        """
+        # Validate first, so that a failed call has no side effect
+        if "past_key_values" not in model_kwargs:
+            raise ValueError("Cannot use prefill chunking without a cache")
+
+        # Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may
+        # end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings.
+        # Only ever raise the limit: a larger value set by the user must not be lowered.
+        torch._dynamo.config.cache_size_limit = max(torch._dynamo.config.cache_size_limit, 64)
+
+        chunk_size = generation_config.prefill_chunk_size
+        # Only chunk up to the token just before last (here we simply prefill the cache)
+        input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)
+
+        # NOTE(review): this always uses the compiled call, independently of whether the decoding loop compiles
+        # the forward -- confirm this is intended.
+        model_forward = self.get_compiled_call(generation_config.compile_config)
+        attention_mask = model_kwargs.pop("attention_mask", None)
+
+        past_length = 0
+        for input_chunk in input_chunks:
+            if input_chunk.shape[-1] == 0:
+                # Single-token prompt: there is nothing to prefill
+                continue
+            current_length = past_length + input_chunk.shape[-1]
+            # Prepare inputs
+            if attention_mask is not None:
+                model_kwargs["attention_mask"] = attention_mask[:, :current_length]
+            model_kwargs["cache_position"] = torch.arange(
+                past_length, current_length, dtype=torch.long, device=input_chunk.device
+            )
+            model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
+            model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)
+
+            outputs = model_forward(**model_inputs, return_dict=True)
+
+            model_kwargs["past_key_values"] = outputs.past_key_values
+            past_length = current_length
+
+        # Restore the full mask, but do not add an explicit `None` entry the caller never passed
+        if attention_mask is not None:
+            model_kwargs["attention_mask"] = attention_mask
+        # Decoding resumes at the first position that was not prefilled, i.e. the last prompt token
+        model_kwargs["cache_position"] = torch.arange(
+            past_length, past_length + 1, dtype=torch.long, device=input_ids.device
+        )
+        model_kwargs.pop("position_ids", None)
+
+        return model_kwargs
+
def _speculative_sampling(
candidate_input_ids,
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
index da8c9cd4c6..8d03c5cf79 100755
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -53,7 +53,7 @@ _import_structure = {
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
- "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
+ "fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
"finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
"fsdp": ["is_fsdp_managed_module"],
"ggml": [
@@ -192,7 +192,7 @@ if TYPE_CHECKING:
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
- from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
+ from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
from .fsdp import is_fsdp_managed_module
from .ggml import (
diff --git a/src/transformers/integrations/compressed_tensors.py b/src/transformers/integrations/compressed_tensors.py
new file mode 100644
index 0000000000..752227914d
--- /dev/null
+++ b/src/transformers/integrations/compressed_tensors.py
@@ -0,0 +1,54 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from transformers.utils import is_torch_available
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
+
+
+def skip(*args, **kwargs):
+ pass
+
+
+class CompressedExpertsLinear(nn.Module):
+ """
+ A module that implements a compressed version of a list of expert modules.
+ This is specifically designed to work with Llama4TextExperts in MoE layers.
+ """
+
+ def __init__(self, config):
+ # Skip random weight initialization for experts. Otherwise,
+ # the init of this module would take over minutes. For a model
+ # with tens of layers of experts, it would easily take over 20 minutes.
+ nn.init.kaiming_uniform_ = skip
+ nn.init.uniform_ = skip
+ nn.init.normal_ = skip
+ super().__init__()
+ self.num_experts = config.num_local_experts
+ self.expert_modules = nn.ModuleList([Llama4TextMLP(config) for _ in range(self.num_experts)])
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
+ expert_routed_out_list = []
+ for expert_idx in range(self.num_experts):
+ expert_routed_out_list.append(self.expert_modules[expert_idx](hidden_states[expert_idx]))
+ routed_out = torch.cat(expert_routed_out_list, dim=0)
+ return routed_out
diff --git a/src/transformers/integrations/fbgemm_fp8.py b/src/transformers/integrations/fbgemm_fp8.py
index 71c2b570cc..5cca37f515 100644
--- a/src/transformers/integrations/fbgemm_fp8.py
+++ b/src/transformers/integrations/fbgemm_fp8.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from ..activations import ACT2FN
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
@@ -28,36 +29,36 @@ if is_fbgemm_gpu_available():
logger = logging.get_logger(__name__)
-class FbgemmFp8Linear(torch.nn.Module):
+class FbgemmFp8Linear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
- super().__init__()
+ super().__init__(in_features, out_features, bias)
self.in_features = in_features
self.out_features = out_features
- self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
- self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
+ self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
+ self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
if bias:
- self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
+ self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
else:
self.bias = None
def forward(self, x):
- num_tokens = None
# quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
output_shape = (*x.shape[:-1], -1)
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
- x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
+ x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
)
# moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
# The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
+ weight_scale_float32 = self.weight_scale.to(torch.float32)
output = torch.ops.fbgemm.f8f8bf16_rowwise(
- x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
+ x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
)
output = output + self.bias if self.bias is not None else output
# Hacky for now, we have the output to the device of x
@@ -67,6 +68,92 @@ class FbgemmFp8Linear(torch.nn.Module):
return output
+class FbgemmFp8Llama4TextExperts(nn.Module):
+ def __init__(self, config, dtype=torch.float32):
+ super().__init__()
+ self.num_experts = config.num_local_experts
+ self.intermediate_size = config.intermediate_size
+ self.hidden_size = config.hidden_size
+ self.expert_dim = self.intermediate_size
+ self.act_fn = ACT2FN[config.hidden_act]
+ # Register FP8 buffers for gate_up_proj
+ self.gate_up_proj = torch.nn.Parameter(
+ torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn)
+ )
+ self.gate_up_proj_scale = torch.nn.Parameter(
+ torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32)
+ )
+ # Register FP8 buffers for down_proj
+ self.down_proj = torch.nn.Parameter(
+ torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn)
+ )
+ self.down_proj_scale = torch.nn.Parameter(
+ torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32)
+ )
+ # Register input scale upper bound
+ self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
+
+ def forward(self, hidden_states):
+ """
+ Args:
+ hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
+ Returns:
+ torch.Tensor: (batch_size * token_num, hidden_size)
+ """
+ # Reshape hidden states for expert computation
+ hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
+ num_tokens = None
+
+ # Pre-allocate tensor for all expert outputs with same shape as hidden_states
+ next_states = torch.empty_like(hidden_states)
+
+ for i in range(self.num_experts):
+ # Extract expert's hidden states
+ expert_hidden = hidden_states[i]
+ expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
+ # Quantize for this expert
+ expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+ expert_hidden_reshaped, num_tokens, self.input_scale_ub
+ )
+ sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
+ gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
+
+ gate = torch.ops.fbgemm.f8f8bf16_rowwise(
+ expert_quantized,
+ self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
+ expert_scale,
+ gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
+ use_fast_accum=True,
+ )
+
+ up = torch.ops.fbgemm.f8f8bf16_rowwise(
+ expert_quantized,
+ self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
+ expert_scale,
+ gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
+ use_fast_accum=True,
+ )
+
+ activated = up * self.act_fn(gate)
+
+ activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+ activated, num_tokens, self.input_scale_ub
+ )
+
+ down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
+ expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
+ activated_quantized,
+ self.down_proj[i].transpose(0, 1).contiguous(),
+ activated_scale,
+ down_proj_scale_float32[i].view(-1, 1).contiguous(),
+ use_fast_accum=True,
+ )
+
+ next_states[i] = expert_output
+ next_states = next_states.to(hidden_states.device)
+ return next_states.view(-1, self.hidden_size)
+
+
def _replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=None,
@@ -74,12 +161,17 @@ def _replace_with_fbgemm_fp8_linear(
quantization_config=None,
has_been_replaced=False,
pre_quantized=False,
+ config=None,
+ tp_plan=None,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
+
+ import re
+
if current_key_name is None:
current_key_name = []
@@ -105,9 +197,27 @@ def _replace_with_fbgemm_fp8_linear(
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# set non persistant buffer outside of init_empty_weights
+ model._modules[name].input_scale_ub = torch.tensor(
+ [quantization_config.activation_scale_ub],
+ dtype=torch.float,
+ )
+ if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
+ current_key_name_str = ".".join(current_key_name)
+ if not any(
+ (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+ ):
+ with init_empty_weights(include_buffers=True):
+ tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[
+ re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")
+ ]
+ tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
+ model._modules[name] = FbgemmFp8Llama4TextExperts(
+ config.text_config,
+ )
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub], dtype=torch.float
)
+
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fbgemm_fp8_linear(
module,
@@ -116,6 +226,8 @@ def _replace_with_fbgemm_fp8_linear(
quantization_config,
has_been_replaced=has_been_replaced,
pre_quantized=pre_quantized,
+ config=config,
+ tp_plan=tp_plan,
)
# Remove the last key for recursion
current_key_name.pop(-1)
@@ -123,7 +235,13 @@ def _replace_with_fbgemm_fp8_linear(
def replace_with_fbgemm_fp8_linear(
- model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
+ model,
+ modules_to_not_convert=None,
+ current_key_name=None,
+ quantization_config=None,
+ pre_quantized=False,
+ config=None,
+ tp_plan=None,
):
"""
A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
@@ -151,9 +269,14 @@ def replace_with_fbgemm_fp8_linear(
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
- model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
+ model,
+ modules_to_not_convert,
+ current_key_name,
+ quantization_config,
+ pre_quantized=pre_quantized,
+ config=config,
+ tp_plan=tp_plan,
)
-
if not has_been_replaced:
logger.warning(
"You are loading your model using FP8 quantization but no linear modules were found in your model."
diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index b0a054998c..1aa146e4a4 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -34,10 +34,7 @@ from ..utils import is_torch_flex_attn_available
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import (
- BlockMask,
- flex_attention,
- )
+ from torch.nn.attention.flex_attention import BlockMask, flex_attention
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
@@ -64,14 +61,23 @@ class WrappedFlexAttention:
Initialize or update the singleton instance.
"""
if self._is_flex_compiled is False:
- self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
+ self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor")
self._is_flex_compiled = True
def __call__(self):
return self._compiled_flex_attention
-def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
+Offset = Union[torch.Tensor, int]
+
+
+def make_flex_block_causal_mask(
+ attention_mask_2d: torch.Tensor,
+ attention_chunk_size: Optional[int] = None,
+ query_length=None,
+ key_length=None,
+ offsets: Optional[Tuple[Offset, Offset]] = None,
+) -> "BlockMask":
"""
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
@@ -94,10 +100,13 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
Returns:
BlockMask
"""
+ attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
device = attention_mask_2d.device
+ document_ids = attention_mask_2d.clone()
- document_ids = attention_mask_2d
- batch_size, total_seq_len = document_ids.shape
+ if attention_chunk_size is not None:
+ # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
+ document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
@@ -112,18 +121,30 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
- causal_mask = q_idx >= kv_idx
+ causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
- padding_mask = document_ids[batch_idx, q_idx] > 0
- return causal_mask & document_mask & padding_mask
+ padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
+ final_mask = causal_mask & padding_mask & document_mask
+ return final_mask
+ if offsets is not None:
+ q_offset = offsets[0]
+ kv_offset = offsets[1]
+
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ offset_q = q_idx + q_offset
+ offset_kv = kv_idx + kv_offset
+ return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
+ else:
+ mask_mod = causal_mask_mod
return create_block_causal_mask_flex(
- mask_mod=causal_mask_mod,
- B=batch_size,
+ mask_mod=mask_mod,
+ B=1,
H=None, # attention head
- Q_LEN=total_seq_len,
- KV_LEN=total_seq_len,
+ Q_LEN=query_length,
+ KV_LEN=key_length,
device=device,
+ _compile=True,
)
@@ -144,6 +165,18 @@ def compile_friendly_flex_attention(
)
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
def flex_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
@@ -174,14 +207,25 @@ def flex_attention_forward(
score = score + head_mask[batch_idx][head_idx][0][0]
return score
+ enable_gqa = True
+ num_local_query_heads = query.shape[1]
+
+ # When running TP this helps:
+ if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
+ key = repeat_kv(key, query.shape[1] // key.shape[1])
+ value = repeat_kv(value, query.shape[1] // value.shape[1])
+ enable_gqa = False
+
+ kernel_options = kwargs.get("kernel_options", None)
attn_output, attention_weights = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
- enable_gqa=True,
+ enable_gqa=enable_gqa,
scale=scaling,
+ kernel_options=kernel_options,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,
diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index 890832a26e..9c924c048a 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -31,7 +31,7 @@ def sdpa_attention_forward(
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
- if attention_mask is not None:
+ if attention_mask is not None and causal_mask.ndim == 4:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py
index f8a04c037b..34b21444fe 100644
--- a/src/transformers/integrations/tensor_parallel.py
+++ b/src/transformers/integrations/tensor_parallel.py
@@ -61,6 +61,21 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
return [single_size] * blocks
+# Lookup table from dtype-name strings to torch dtypes. `get_packed_weights` indexes it with the
+# string returned by `slice_.get_dtype()` to cast the sliced tensor back to its stored dtype
+# (notably after the temporary float16 upcast applied to "F8_E4M3" slices).
+str_to_torch_dtype = {
+ "BOOL": torch.bool,
+ "U8": torch.uint8,
+ "I8": torch.int8,
+ "I16": torch.int16,
+ "F16": torch.float16,
+ "BF16": torch.bfloat16,
+ "I32": torch.int32,
+ "F32": torch.float32,
+ "F64": torch.float64,
+ "I64": torch.int64,
+ "F8_E4M3": torch.float8_e4m3fn,
+}
+
+
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
"""
When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
@@ -106,6 +121,12 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
tensors_slices += range(block_offset + start, block_offset + stop)
block_offset += block_size
+ slice_dtype = slice_.get_dtype()
+ # Handle F8_E4M3 dtype by converting to float16 before slicing
+ # Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
+ if slice_dtype == "F8_E4M3":
+ slice_ = slice_[...].to(torch.float16)
+
if dim == 0:
tensor = slice_[tensors_slices, ...]
elif dim == 1 or dim == -2:
@@ -114,7 +135,7 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
tensor = slice_[..., tensors_slices]
else:
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
- return tensor
+ return tensor.to(str_to_torch_dtype[slice_dtype])
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
@@ -199,11 +220,12 @@ class GatherParallel(TensorParallelLayer):
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
if isinstance(inputs[0], DTensor):
- inputs[0] = inputs[0].to_local()
+ inputs = inputs[0].to_local()
return inputs
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+ # this op cannot be async, otherwise it completely breaks the outputs of models
torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
return outputs
@@ -266,7 +288,7 @@ class ColwiseParallel(TensorParallelLayer):
# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
- input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
+ input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
return input_tensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
@@ -291,7 +313,7 @@ class ColwiseParallel(TensorParallelLayer):
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# outputs is a shard on last dimension DTensor, i.e. Shard(-1)
if outputs.placements != output_layouts:
- outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+ outputs = outputs.redistribute(placements=output_layouts, async_op=False)
# back to local tensor
return outputs.to_local() if use_local_output else outputs
@@ -343,16 +365,6 @@ class RowwiseParallel(TensorParallelLayer):
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
- @staticmethod
- def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
- input_tensor = inputs[0]
- if not isinstance(input_tensor, DTensor):
- input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
-
- if input_layouts != desired_input_layouts:
- input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
- return input_tensor
-
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
# means Rowwise as nn.Linear is input * weight^T + bias, where
@@ -371,6 +383,20 @@ class RowwiseParallel(TensorParallelLayer):
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
return nn.Parameter(parameter)
+ @staticmethod
+ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+ if hasattr(mod, "bias") and mod.bias is not None:
+ mod._bias = mod.bias
+ mod.bias = None
+
+ input_tensor = inputs[0]
+ if not isinstance(input_tensor, DTensor):
+ input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+ if input_layouts != desired_input_layouts:
+ input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
+ return input_tensor
+
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# Rowwise sharding produces partial output, depending on output layouts:
@@ -378,6 +404,8 @@ class RowwiseParallel(TensorParallelLayer):
# 2. to shard -> reduce_scatter
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+ if hasattr(mod, "_bias"):
+ outputs += mod._bias
# back to local tensor if use_local_output is True
return outputs.to_local() if use_local_output else outputs
@@ -418,6 +446,90 @@ class PackedRowwiseParallel(RowwiseParallel):
return nn.Parameter(parameter)
+class SequenceParallel(TensorParallelLayer):
+ """
+ SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with
+ input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
+ `RMSNorm python implementation `__
+
+ This style implements the operation that is described in the paper
+ `Reducing Activation Recomputation in Large Transformer Models `__
+
+ If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded
+ on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input
+ passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would
+ redistribute the input to be sharded on the sequence dimension.
+
+ The output of the ``nn.Module`` will be sharded on the sequence dimension.
+
+ Keyword Args:
+ sequence_dim (int, optional):
+ The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to
+ become a DTensor that is sharded on the sequence dimension, default: 1.
+ use_local_output (bool, optional):
+ Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
+ Returns:
+ A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
+
+ Example::
+ >>> # xdoctest: +SKIP(failing)
+ >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
+ >>> from torch.distributed.device_mesh import init_device_mesh
+ >>> ...
+ >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
+ >>> tp_mesh = init_device_mesh("cuda", (8,))
+ >>>
+ >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
+ >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
+ >>>
+ >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
+ >>> ...
+
+ .. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.
+ ``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom
+ inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
+ to ensure that they are replicated.
+ """
+
+ def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
+ super().__init__()
+ self.input_layouts = (Replicate(),)
+ self.desired_input_layouts = (Shard(1),)
+ self.output_layouts = (Replicate(),)
+ self.use_local_output = use_local_output
+ self.use_dtensor = True
+ self.sequence_sharding = (Shard(sequence_dim),)
+ self.use_local_output = use_local_output
+
+ @staticmethod
+ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+ input_tensor = inputs[0]
+ if not isinstance(input_tensor, DTensor):
+ input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+ if input_layouts != desired_input_layouts:
+ input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
+ return input_tensor
+
+ @staticmethod
+ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+ outputs = outputs.redistribute(
+ placements=(Replicate(),), async_op=True
+ ) # maybe we have to replicate ? because next layer is not sharded
+ return outputs.to_local() # if use_local_output else outputs
+
+ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+ # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
+ # means Colwise as Linear is input * weight^T + bias, where
+ # weight would become Shard(1)
+ parameter = param[:]
+ parameter = parameter.to(param_casting_dtype)
+ if to_contiguous:
+ parameter = parameter.contiguous()
+ if self.use_dtensor:
+ parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
+ return nn.Parameter(parameter)
+
+
SUPPORTED_TP_STYLES = {
"colwise",
"rowwise",
@@ -428,6 +540,7 @@ SUPPORTED_TP_STYLES = {
"local",
"gather",
"local_packed_rowwise",
+ "sequence_parallel",
}
@@ -459,6 +572,8 @@ def translate_to_torch_parallel_style(style: str):
return GatherParallel()
elif style == "local_packed_rowwise":
return PackedRowwiseParallel(use_dtensor=False)
+ elif style == "sequence_parallel":
+ return SequenceParallel()
else:
raise ValueError(f"Unsupported parallel style value: {style}")
@@ -518,6 +633,7 @@ def shard_and_distribute_module(
tp_plan = model._tp_plan
module_to_tp = model.get_submodule(param_name)
current_module_plan = None
+ rank = int(rank)
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
current_module_plan = tp_plan[generic_param_name]
@@ -531,12 +647,18 @@ def shard_and_distribute_module(
module_to_tp._is_hooked = True
if current_module_plan is not None:
- tp_layer = translate_to_torch_parallel_style(current_module_plan)
- param = tp_layer.partition_tensor(
- param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
- )
+ try:
+ tp_layer = translate_to_torch_parallel_style(current_module_plan)
+ param = tp_layer.partition_tensor(
+ param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
+ )
+ except NotImplementedError as e:
+ print(
+ f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
+ )
else:
# TODO log no plan modules in set
+ # print("No plan for", parameter_name,end ="\n")
param = param[...].to(param_casting_dtype)
if is_contiguous:
param = param.contiguous()
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 2e0a389d74..6a3286cbc9 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -484,6 +484,7 @@ str_to_torch_dtype = {
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
+ "F8_E4M3": torch.float8_e4m3fn,
}
if is_torch_greater_or_equal("2.1.0"):
@@ -1914,16 +1915,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
- if self.base_model is self:
- self._pp_plan = (
- self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
- )
- self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
- else:
- self._tp_plan = self._tp_plan or {}
- for name, module in self.named_children():
- if plan := getattr(module, "_tp_plan", None):
- self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
+ self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
+ self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
+ for name, module in self.named_children():
+ if plan := getattr(module, "_tp_plan", None):
+ self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
for _, v in self._tp_plan.items():
@@ -4054,6 +4050,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
import sys
sys.stdout = open(os.devnull, "w")
+ sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assuming sharding the model onto the world
@@ -4238,6 +4235,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
device_map = hf_quantizer.update_device_map(device_map)
+ config = hf_quantizer.update_tp_plan(config)
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
@@ -4370,9 +4368,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
- model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules
+ model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config
)
-
# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
@@ -4901,7 +4898,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
name,
casting_dtype,
to_contiguous,
- tp_device.index,
+ os.environ["RANK"],
device_mesh,
)
@@ -5174,6 +5171,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
(where we want the speed-ups of compiled version with static shapes)."""
# Only reset it if not present or different from previous config
+ if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
+ return self.__call__
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
if (
not hasattr(self, "_compiled_call")
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index b2cde9f4bc..08cec64b41 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -148,6 +148,7 @@ from . import (
levit,
lilt,
llama,
+ llama4,
llava,
llava_next,
llava_next_video,
diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py
index c06017fd76..d21a16beb1 100644
--- a/src/transformers/models/auto/auto_factory.py
+++ b/src/transformers/models/auto/auto_factory.py
@@ -544,10 +544,6 @@ class _BaseAutoModelClass:
if kwargs_orig.get("quantization_config", None) is not None:
kwargs["quantization_config"] = kwargs_orig["quantization_config"]
- # AutoClass-specific config manipulation
- config = copy.deepcopy(config)
- config = cls._prepare_config_for_auto_class(config)
-
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
@@ -570,6 +566,8 @@ class _BaseAutoModelClass:
)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
+ if model_class.config_class == config.sub_configs.get("text_config", None):
+ config = config.get_text_config()
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 9937b55a8b..759b7ad3d9 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -170,6 +170,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("levit", "LevitConfig"),
("lilt", "LiltConfig"),
("llama", "LlamaConfig"),
+ ("llama4", "Llama4Config"),
+ ("llama4_text", "Llama4TextConfig"),
("llava", "LlavaConfig"),
("llava_next", "LlavaNextConfig"),
("llava_next_video", "LlavaNextVideoConfig"),
@@ -519,6 +521,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("llama", "LLaMA"),
("llama2", "Llama2"),
("llama3", "Llama3"),
+ ("llama4", "Llama4"),
+ ("llama4_text", "Llama4ForCausalLM"),
("llava", "LLaVa"),
("llava_next", "LLaVA-NeXT"),
("llava_next_video", "LLaVa-NeXT-Video"),
@@ -776,6 +780,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
("sam_vision_model", "sam"),
+ ("llama4_text", "llama4"),
]
)
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 7cfda72a93..2f9d42fcdb 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -104,6 +104,7 @@ else:
("layoutlmv2", ("LayoutLMv2ImageProcessor",)),
("layoutlmv3", ("LayoutLMv3ImageProcessor",)),
("levit", ("LevitImageProcessor",)),
+ ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
("llava_next_video", ("LlavaNextVideoImageProcessor",)),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 5df90ff5a8..d33d0f20d5 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -17,7 +17,6 @@
import warnings
from collections import OrderedDict
-from ...configuration_utils import PretrainedConfig
from ...utils import logging
from .auto_factory import (
_BaseAutoBackboneClass,
@@ -161,6 +160,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("levit", "LevitModel"),
("lilt", "LiltModel"),
("llama", "LlamaModel"),
+ ("llama4", "Llama4ForConditionalGeneration"),
("longformer", "LongformerModel"),
("longt5", "LongT5Model"),
("luke", "LukeModel"),
@@ -547,6 +547,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"),
("llama", "LlamaForCausalLM"),
+ ("llama4", "Llama4ForCausalLM"),
+ ("llama4_text", "Llama4ForCausalLM"),
("mamba", "MambaForCausalLM"),
("mamba2", "Mamba2ForCausalLM"),
("marian", "MarianForCausalLM"),
@@ -634,6 +636,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
+ ("llama4", "Llama4VisionModel"),
("mllama", "MllamaVisionModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),
@@ -849,6 +852,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("idefics3", "Idefics3ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
+ ("llama4", "Llama4ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
@@ -1492,6 +1496,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
("emu3", "Emu3TextModel"),
("flaubert", "FlaubertModel"),
("ibert", "IBertModel"),
+ ("llama4", "Llama4TextModel"),
("longformer", "LongformerModel"),
("mllama", "MllamaTextModel"),
("mobilebert", "MobileBertModel"),
@@ -1678,30 +1683,6 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag
class AutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
- @classmethod
- def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
- """
- Additional autoclass-specific config post-loading manipulation. In this specific autoclass, if the config has
- a nested text decoder section, uses that section instead.
-
- Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
- config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
- """
- possible_text_config_names = ("decoder", "generator", "text_config")
- text_config_names = []
- for text_config_name in possible_text_config_names:
- if hasattr(config, text_config_name):
- text_config_names += [text_config_name]
-
- text_config = config.get_text_config(decoder=True)
- if text_config_names and type(text_config) in cls._model_mapping.keys():
- warnings.warn(
- "Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
- "`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
- FutureWarning,
- )
- return config
-
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py
index 48081b9df8..6e655edff1 100644
--- a/src/transformers/models/auto/processing_auto.py
+++ b/src/transformers/models/auto/processing_auto.py
@@ -77,6 +77,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("kosmos-2", "Kosmos2Processor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutlmv3", "LayoutLMv3Processor"),
+ ("llama4", "Llama4Processor"),
("llava", "LlavaProcessor"),
("llava_next", "LlavaNextProcessor"),
("llava_next_video", "LlavaNextVideoProcessor"),
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 42ff109cca..eb54d95ab6 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -292,6 +292,20 @@ else:
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
+ (
+ "llama4",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
+ (
+ "llama4_text",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
diff --git a/src/transformers/models/llama4/__init__.py b/src/transformers/models/llama4/__init__.py
new file mode 100644
index 0000000000..59fe1686ce
--- /dev/null
+++ b/src/transformers/models/llama4/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_llama4 import *
+ from .image_processing_llama4_fast import *
+ from .modeling_llama4 import *
+ from .processing_llama4 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/llama4/configuration_llama4.py b/src/transformers/models/llama4/configuration_llama4.py
new file mode 100644
index 0000000000..1c4c00f48f
--- /dev/null
+++ b/src/transformers/models/llama4/configuration_llama4.py
@@ -0,0 +1,432 @@
+# coding=utf-8
+# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Llama4VisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Llama4VisionModel`]. It is used to instantiate a
+ Llama4 vision model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Llama4 109B.
+
+ e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+ num_hidden_layers (`int`, *optional*, defaults to 34):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of channels in the input image.
+ intermediate_size (`int`, *optional*, defaults to 5632):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ vision_output_dim (`int`, *optional*, defaults to 7680):
+ Dimensionality of the vision model output. Includes output of transformer
+ encoder with intermediate layers and global transformer encoder.
+ image_size (`int`, *optional*, defaults to 448):
+ The size (resolution) of each image *tile*.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the layer normalization layers.
+ vision_feature_layer (`int`, *optional*, defaults to -1): TODO
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): TODO
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ pixel_shuffle_ratio (`float`, *optional*, defaults to 0.5): TODO
+ projector_input_dim (`int`, *optional*, defaults to 4096): TODO
+ projector_output_dim (`int`, *optional*, defaults to 4096): TODO
+ multi_modal_projector_bias (`bool`, *optional*, defaults to `False`): TODO
+ projector_dropout (`float`, *optional*, defaults to 0.0): TODO
+ attention_dropout (`float`, *optional*, defaults to 0.0): TODO
+ rope_theta (`int`, *optional*, defaults to 10000): TODO
+ """
+
+ base_model_tp_plan = {
+ "model.layers.*.self_attn.q_proj": "colwise",
+ "model.layers.*.self_attn.k_proj": "colwise",
+ "model.layers.*.self_attn.v_proj": "colwise",
+ "model.layers.*.self_attn.o_proj": "rowwise",
+ "vision_adapter.mlp.fc1": "colwise",
+ "vision_adapter.mlp.fc2": "rowwise",
+ "patch_embedding.linear": "colwise_rep",
+ }
+ model_type = "llama4_vision_model"
+ base_config_key = "vision_config"
+
+ def __init__(
+ self,
+ hidden_size: int = 768,
+ hidden_act: str = "gelu",
+ num_hidden_layers: int = 34,
+ num_attention_heads: int = 16,
+ num_channels: int = 3,
+ intermediate_size: int = 5632,
+ vision_output_dim: int = 7680,
+ image_size: int = 448,
+ patch_size: int = 14,
+ norm_eps: float = 1e-5,
+ vision_feature_layer=-1,
+ vision_feature_select_strategy="default",
+ initializer_range: float = 0.02,
+ pixel_shuffle_ratio=0.5,
+ projector_input_dim=4096,
+ projector_output_dim=4096,
+ multi_modal_projector_bias=False,
+ projector_dropout=0.0,
+ attention_dropout=0.0,
+ rope_theta=10000,
+ **kwargs,
+ ):
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.num_hidden_layers = num_hidden_layers
+ self.num_channels = num_channels
+ self.intermediate_size = intermediate_size
+ self.image_size = image_size
+ self.vision_output_dim = vision_output_dim
+ self.patch_size = patch_size
+ self.norm_eps = norm_eps
+ self.num_attention_heads = num_attention_heads
+ self.initializer_range = initializer_range
+ self.pixel_shuffle_ratio = pixel_shuffle_ratio
+ self.projector_input_dim = projector_input_dim
+ self.projector_output_dim = projector_output_dim
+ self.multi_modal_projector_bias = multi_modal_projector_bias
+ self.projector_dropout = projector_dropout
+ self.attention_dropout = attention_dropout
+ self.vision_feature_layer = vision_feature_layer
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.rope_theta = rope_theta
+ super().__init__(**kwargs)
+
+
+class Llama4TextConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Llama4TextModel`]. It is used to instantiate a
+ Llama4 text model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Llama4 109B.
+
+ e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 202048):
+ Vocabulary size of the Llama4 text model. Defines the maximum number of different tokens that can be represented
+ by the `input_ids` passed when calling [`Llama4TextModel`].
+ hidden_size (`int`, *optional*, defaults to 5120):
+ Dimensionality of the embeddings and hidden states.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ intermediate_size_mlp (`int`, *optional*, defaults to 16384): TODO
+ num_hidden_layers (`int`, *optional*, defaults to 48):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 40):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*, defaults to 8):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If not
+ specified, will default to `num_attention_heads`.
+ head_dim (`int`, *optional*, defaults to 128): TODO
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the encoder and pooler.
+ max_position_embeddings (`int`, *optional*, defaults to 131072):
+ The maximum sequence length that this model might ever be used with.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions.
+ pad_token_id (`int`, *optional*):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the beginning of sentence token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the end of sentence token.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to `500000.0`):
+ The base period of the RoPE embeddings.
+ attention_dropout (`float`, *optional*, defaults to 0.0): TODO
+ num_experts_per_tok (`int`, *optional*, defaults to 1): TODO
+ num_local_experts (`int`, *optional*, defaults to 16): TODO
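+ moe_layers (`List[int]`, *optional*): TODO. If `None`, it is set to `list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))`.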
+ interleave_moe_layer_step (`int`, *optional*, defaults to 1): TODO
+ use_qk_norm (`bool`, *optional*, defaults to `True`): TODO
+ output_router_logits (`bool`, *optional*, defaults to `False`): TODO
+ router_aux_loss_coef (`float`, *optional*, defaults to 0.001): TODO
+ router_jitter_noise (`float`, *optional*, defaults to 0.0): TODO
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+ accordingly.
+ Expected contents:
+ `rope_type` (`str`):
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+ 'llama3'], with 'default' being the original RoPE implementation.
+ `factor` (`float`, *optional*):
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+ original maximum pre-trained length.
+ `original_max_position_embeddings` (`int`, *optional*):
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+ pretraining.
+ `attention_factor` (`float`, *optional*):
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
+ `factor` field to infer the suggested value.
+ `beta_fast` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+ ramp function. If unspecified, it defaults to 32.
+ `beta_slow` (`float`, *optional*):
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+ ramp function. If unspecified, it defaults to 1.
+ `short_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `long_factor` (`List[float]`, *optional*):
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+ size divided by the number of attention heads divided by 2
+ `low_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+ `high_freq_factor` (`float`, *optional*):
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+
+
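+ no_rope_layers (`List[int]`, *optional*): TODO. If unset or empty, it defaults to one entry per hidden layer: `0` for every `no_rope_layer_interval`-th layer and `1` for all others.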
+ no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO
+ attention_chunk_size (`int`, *optional*, defaults to 8192):
+
+ attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO
+ floor_scale (`int`, *optional*, defaults to 8192): TODO
+ attn_scale (`float`, *optional*, defaults to 0.1): TODO
+
+ Example:
+ """
+
+ model_type = "llama4_text"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ base_model_tp_plan = {
+ "layers.*.self_attn.q_proj": "colwise",
+ "layers.*.self_attn.k_proj": "colwise",
+ "layers.*.self_attn.v_proj": "colwise",
+ "layers.*.self_attn.o_proj": "rowwise",
+ "layers.*.input_layernorm.weight": "sequence_parallel",
+ "layers.*.post_attention_layernorm.weight": "sequence_parallel",
+ "norm.weight": "sequence_parallel",
+ "layers.*.feed_forward.shared_expert.gate_proj": "local_colwise",
+ "layers.*.feed_forward.shared_expert.up_proj": "local_colwise",
+ "layers.*.feed_forward.shared_expert.down_proj": "local_rowwise",
+ "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise", # row because not linear
+ "layers.*.feed_forward.experts.down_proj": "local_colwise", # col because not linear
+ "layers.*.feed_forward.experts": "local",
+ "layers.*.feed_forward.gate_proj": "local_colwise",
+ "layers.*.feed_forward.up_proj": "local_colwise",
+ "layers.*.feed_forward.down_proj": "local_rowwise",
+ "layers.*.feed_forward": "gather",
+ }
+
+ def __init__(
+ self,
+ vocab_size=202048,
+ hidden_size=5120,
+ intermediate_size=8192,
+ intermediate_size_mlp=16384,
+ num_hidden_layers=48,
+ num_attention_heads=40,
+ num_key_value_heads=8,
+ head_dim=128,
+ hidden_act="silu",
+ max_position_embeddings=4096 * 32,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ rope_theta=500000,
+ attention_dropout=0.0,
+ num_experts_per_tok=1,
+ num_local_experts=16,
+ moe_layers=None,
+ interleave_moe_layer_step=1,
+ use_qk_norm=True,
+ output_router_logits=False,
+ router_aux_loss_coef=0.001,
+ router_jitter_noise=0.0,
+ rope_scaling=None,
+ no_rope_layers=None,
+ no_rope_layer_interval=4,
+ attention_chunk_size=8192,
+ attn_temperature_tuning=4,
+ floor_scale=8192,
+ attn_scale=0.1,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+ self.attn_temperature_tuning = attn_temperature_tuning
+ self.attn_scale = attn_scale
+ self.floor_scale = floor_scale
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.intermediate_size_mlp = intermediate_size_mlp
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.rope_scaling = rope_scaling
+ self.attention_bias = False
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.attention_dropout = attention_dropout
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+ self.use_qk_norm = use_qk_norm
+
+ self.num_experts_per_tok = num_experts_per_tok
+ self.num_local_experts = num_local_experts
+
+ self.output_router_logits = output_router_logits
+ self.router_aux_loss_coef = router_aux_loss_coef
+ self.router_jitter_noise = router_jitter_noise
+ default_no_rope_layers = [
+ int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
+ ]
+
+ # no_rope_layers == [] is invalid as we cannot have 0 layers
+ self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers
+
+ self.interleave_moe_layer_step = interleave_moe_layer_step
+ self.moe_layers = (
+ moe_layers
+ if moe_layers is not None
+ else list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))
+ )
+ self.attention_chunk_size = attention_chunk_size
+
+
+class Llama4Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Llama4ForConditionalGeneration`]. It is used to instantiate a
+ Llama4 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Llama4 109B.
+
+ e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vision_config (`Llama4VisionConfig`, *optional*):
+ The Llama4 Vision config.
+ text_config (`Llama4TextConfig`, *optional*):
+ The Llama4 Text config.
+ boi_token_index (`int`, *optional*, defaults to 200080):
+ The begin-of-image token index to wrap the image prompt.
+ eoi_token_index (`int`, *optional*, defaults to 200081):
+ The end-of-image token index to wrap the image prompt.
+ image_token_index (`int`, *optional*, defaults to 200092):
+ The image token index to encode the image prompt.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether the model's input and output word embeddings should be tied.
+
+ ```python
+ >>> from transformers import Llama4ForConditionalGeneration, Llama4Config
+
+ >>> # Initializing a Llama4 109B style configuration
+ >>> configuration = Llama4Config()
+
+ >>> # Initializing a model from the Llama4 109B style configuration
+ >>> model = Llama4ForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "llama4"
+ sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig}
+ base_model_tp_plan = {
+ "multi_modal_projector.linear_1": "colwise_rep",
+ }
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ boi_token_index=200080,
+ eoi_token_index=200081,
+ image_token_index=200092,
+ tie_word_embeddings=False,
+ **kwargs,
+ ):
+ if vision_config is None:
+ self.vision_config = Llama4VisionConfig()
+ logger.info("vision_config is None, using default llama4 vision config")
+ elif isinstance(vision_config, dict):
+ self.vision_config = Llama4VisionConfig(**vision_config)
+ elif isinstance(vision_config, Llama4VisionConfig):
+ self.vision_config = vision_config
+
+ self.boi_token_index = boi_token_index
+ self.eoi_token_index = eoi_token_index
+ self.image_token_index = image_token_index
+
+ if text_config is None:
+ self.text_config = Llama4TextConfig()
+ logger.info("text_config is None, using default llama4 text config")
+ elif isinstance(text_config, dict):
+ self.text_config = Llama4TextConfig(**text_config)
+ elif isinstance(text_config, Llama4TextConfig):
+ self.text_config = text_config
+
+ super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+
+__all__ = ["Llama4Config", "Llama4TextConfig", "Llama4VisionConfig"]
diff --git a/src/transformers/models/llama4/convert_llama4_weights_to_hf.py b/src/transformers/models/llama4/convert_llama4_weights_to_hf.py
new file mode 100644
index 0000000000..d1cd6b1993
--- /dev/null
+++ b/src/transformers/models/llama4/convert_llama4_weights_to_hf.py
@@ -0,0 +1,736 @@
+import argparse
+import gc
+import io
+import json
+import os
+import re
+from typing import List, Optional
+
+import torch
+from tokenizers import AddedToken, processors
+from tqdm import tqdm
+
+from transformers import (
+ GenerationConfig,
+ Llama4Config,
+ Llama4ForConditionalGeneration,
+ Llama4ImageProcessorFast,
+ Llama4Processor,
+ Llama4TextConfig,
+ Llama4VisionConfig,
+ PreTrainedTokenizerFast,
+)
+from transformers.integrations.tiktoken import TikTokenConverter
+
+
+_OFFLINE_QUANT_COMPATIBLE = os.environ.get("OFFLINE_QUANT_COMPATIBLE", "0") == "1"
+
+torch.serialization.add_safe_globals([io.BytesIO])
+# fmt: off
+# `None` means we drop the key
+
+
+weight_postfix = ".weight" if _OFFLINE_QUANT_COMPATIBLE else ""
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+ # CausalLM keys
+ r"output.weight": r"language_model.lm_head.weight",
+ r"\nnorm.weight": r"\nlanguage_model.model.norm.weight",
+ # Model keys
+ r"tok_embeddings.weight": r"language_model.model.embed_tokens.weight",
+ r"freq_cis": None,
+ r"rope.freqs": None,
+ r"layers.(\d+).attention_norm.weight": r"language_model.model.layers.\1.input_layernorm.weight",
+ r"layers.(\d+).attention.wqkv.layer_norm_weight": r"language_model.model.layers.\1.input_layernorm.weight",
+ r"layers.(\d+).feed_forward.norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
+ r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight",
+ r"layers.(\d+).attention.wqkv.weight": r"language_model.model.layers.\1.self_attn.qkv_proj.weight",
+
+ # MoE keys: no simple MLPmodel.
+ r"layers.(\d+).feed_forward.experts.moe_w_in_eD_F": r"language_model.model.layers.\1.feed_forward.experts.gate_proj" + weight_postfix, # will be fused with up
+ r"layers.(\d+).feed_forward.experts.moe_w_out_eF_D": r"language_model.model.layers.\1.feed_forward.experts.down_proj" + weight_postfix, # expert win
+ r"layers.(\d+).feed_forward.experts.moe_w_swiglu_eD_F": r"language_model.model.layers.\1.feed_forward.experts.up_proj" + weight_postfix, # fused with up
+ r"layers.(\d+).feed_forward.router_DE": r"language_model.model.layers.\1.feed_forward.router.weight", # used for top
+ r"layers.(\d+).feed_forward.w_in_shared_FD": r"language_model.model.layers.\1.feed_forward.shared_expert.gate_proj", # might need to be fused for efficiency?
+ r"layers.(\d+).feed_forward.w_out_shared_DF": r"language_model.model.layers.\1.feed_forward.shared_expert.down_proj", # might need to be fused for efficiency?
+ r"layers.(\d+).feed_forward.w_swiglu_FD": r"language_model.model.layers.\1.feed_forward.shared_expert.up_proj", # might need to be fused for efficiency?
+ r"layers.(\d+).feed_forward.global_gate_stats_3E": None,
+ # Unused keys in load hooks (explicitly removed)
+ r'layers.(\d+).attention.wqkv._extra_state': None,
+ r'layers.(\d+).attention.wo._extra_state': None,
+ # Key apparently unused in base models
+ r'layers.(\d+).feed_forward.expert_activation_DE': None,
+
+ # MLP layer variant
+ r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.feed_forward.gate_proj.weight", # might need to be fused for efficiency?
+ r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.feed_forward.up_proj.weight", # might need to be fused for efficiency?
+ # r"layers.(\d+).feed_forward.mlp.fc1_weight": r"language_model.model.layers.\1.feed_forward.gate_up_proj.weight",
+ r"layers.(\d+).feed_forward.mlp.fc2_weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight",
+ r"layers.(\d+).feed_forward.mlp.layer_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight",
+
+ # Vision encoder mapping
+ r"vision_embeddings.vision_encoder.conv1._linear": r"vision_model.patch_embedding.linear",
+ r'vision_embeddings.vision_adapter.mlp.c_fc': r"vision_model.vision_adapter.mlp.fc1",
+ r'vision_embeddings.vision_adapter.mlp.c_proj': r"vision_model.vision_adapter.mlp.fc2",
+ r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).attn.wq.(weight|bias)": r"vision_model.model.layers.\1.self_attn.q_proj.\2",
+ r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).attn.wk.(weight|bias)": r"vision_model.model.layers.\1.self_attn.k_proj.\2",
+ r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).attn.wv.(weight|bias)": r"vision_model.model.layers.\1.self_attn.v_proj.\2",
+ r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).attn.wo.(weight|bias)": r"vision_model.model.layers.\1.self_attn.o_proj.\2",
+ r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).mlp.c_fc": r"vision_model.model.layers.\1.mlp.fc1",
+ r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).mlp.c_proj": r"vision_model.model.layers.\1.mlp.fc2",
+ r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).ln_1.(weight|bias)": r"vision_model.model.layers.\1.input_layernorm.\2",
+ r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).ln_2.(weight|bias)": r"vision_model.model.layers.\1.post_attention_layernorm.\2",
+ # r'vision_embeddings.vision_encoder.ln_(1|2).(weight|bias)': r'vision_model.transformer.vision_encoder.layernorm_\1.\2',
+ r'vision_embeddings.vision_encoder.ln_post': r'vision_model.layernorm_post',
+ r'vision_embeddings.vision_encoder.ln_pre': r'vision_model.layernorm_pre',
+ r'vision_embeddings.vision_encoder.class_embedding': r'vision_model.class_embedding',
+ r"vision_embeddings.vision_encoder.positional_embedding_vlm": r"vision_model.positional_embedding_vlm",
+ r"vision_embeddings.vision_encoder.(?=\w)": r"vision_model.model.",
+ r"vision_projection.weight": r"multi_modal_projector.linear_1.weight",
+}
+# fmt: on
+
+
+def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
+    """
+    Rename all original checkpoint keys in one pass.
+
+    The keys are joined into a single newline-separated text, every pattern of
+    `ORIGINAL_TO_CONVERTED_KEY_MAPPING` is substituted over that text in order (a `None`
+    replacement substitutes the empty string), and the text is split back into lines.
+
+    Returns a dict mapping each original key to its renamed counterpart, or an empty dict
+    when `state_dict_keys` is `None`. This should be applied only once, on the full key list.
+    """
+    if state_dict_keys is None:
+        return {}
+
+    original_text = "\n".join(state_dict_keys)
+    renamed_text = original_text
+    for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
+        # `None` in the mapping means "drop the key": the matched text becomes empty.
+        renamed_text = re.sub(pattern, "" if replacement is None else replacement, renamed_text)
+    return dict(zip(original_text.split("\n"), renamed_text.split("\n")))
+
+
+def permute_for_rope(input_tensor, n_heads, dim1, dim2):
+    """
+    Permute query/key projection weights when moving from the complex ROPE formulation
+    to the sin/cos one, so the permutation does not have to happen on the fly.
+
+    The `(dim1, dim2)` input is viewed as `(n_heads, dim1 // n_heads // 2, 2, dim2)`, the two
+    middle axes are swapped, and the result is flattened back to `(dim1, dim2)`.
+    """
+    half_head_dim = dim1 // n_heads // 2
+    grouped = input_tensor.view(n_heads, half_head_dim, 2, dim2)
+    return grouped.permute(0, 2, 1, 3).reshape(dim1, dim2)
+
+
+def is_param_same_across_shards(key):
+    """
+    Tell whether a parameter is replicated on every checkpoint shard.
+
+    Returns `True` when `key` matches one of the patterns below, and `False` when the
+    parameter differs across shards and therefore needs to be concatenated.
+    """
+    replicated_patterns = (
+        r"language_model.layers.(\d+).(.*)layernorm.weight",
+        r"language_model.norm.weight",
+        r"router.weight",
+        r"feed_forward.global_gate_stats",
+        # not all vision weights are sharded, some are repeated
+        r"vision_model.class_embedding",
+        r"vision_model.positional_embedding_vlm",
+        r"vision_embeddings.vision_encoder.positional_embedding_vlm",
+        r"vision_model.model.layers.(\d+).self_attn.o_proj.bias",
+        r"vision_model.model.layers.(\d+).input_layernorm",
+        r"vision_model.model.layers.(\d+).post_attention_layernorm",
+        r"vision_model.layernorm_pre",
+        r"vision_model.layernorm_post",
+        r"vision_model.model.layers.(\d+).mlp.fc2.bias",
+        r"norm.weight",
+    )  # fmt: skip
+    for pattern in replicated_patterns:
+        if re.search(pattern, key) is not None:
+            return True
+    return False
+
+
+def get_concat_dim(key):
+    """
+    Return the dimension to concatenate the sharded weights on.
+
+    Returns `1` when `key` matches (via `re.search`) one of the patterns listed below,
+    and `0` for every other key.
+    """
+    concat_dim_1 = [
+        # language dim 1 sharded weights
+        "feed_forward.router.weight",
+        "self_attn.o_proj",
+        "experts.gate_proj",
+        "experts.up_proj",
+        "expert.down_proj",
+        # "feed_forward.up_proj",
+        # "feed_forward.gate_proj",
+        "feed_forward.down_proj",
+        "global_gate_stats",
+        # vision dim1 sharded stuff
+        "mlp.fc2.weight",  # covers all rowparallels across vis
+    ]  # fmt: skip
+    # `fmt: skip` (not `fmt: off`) so the formatter is only disabled for this one statement.
+    return 1 if any(re.search(pattern, key) for pattern in concat_dim_1) else 0
+
+
+def compute_intermediate_size(hidden_dim, multiple_of=1024, ffn_dim_multiplier=1.3):
+    """
+    Derive a feed-forward width from `hidden_dim`.
+
+    Takes 4 * floor(2 * hidden_dim / 3), scales it by `ffn_dim_multiplier` (truncating to an
+    int), then rounds the result up to the next multiple of `multiple_of`.
+    """
+    two_thirds = int(2 * hidden_dim / 3)
+    scaled = int(ffn_dim_multiplier * (4 * two_thirds))
+    n_blocks = (scaled + multiple_of - 1) // multiple_of
+    return multiple_of * n_blocks
+
+
+# Ignore extra info - h/t Aritra
+def safe_load(filename):
+    """
+    Load one checkpoint shard on CPU (memory-mapped, `weights_only=True`) and return its
+    entries as a dict, leaving out every value that is an `io.BytesIO` object.
+    """
+    # Can use weights_only because io.BytesIO was registered, but we still need to skip those objects
+    raw_shard = torch.load(filename, weights_only=True, map_location="cpu", mmap=True)
+    kept = {}
+    for name, entry in raw_shard.items():
+        if isinstance(entry, io.BytesIO):
+            continue
+        kept[name] = entry
+    return kept
+
+
+# Unpack mlp projections - possibly to be removed when they are fused
+def preprocess_keys(state_dict):
+    """
+    Return a new dict in which every `...mlp.fc1_weight` entry is split in two halves along
+    dim 0 and stored as `<prefix>w1.weight` (first half) and `<prefix>w3.weight` (second half).
+    All other entries are copied over unchanged.
+    """
+    fused_marker = "mlp.fc1_weight"
+    unpacked = {}
+    for name, tensor in state_dict.items():
+        if fused_marker not in name:
+            unpacked[name] = tensor
+            continue
+        prefix = name.split(fused_marker, 1)[0]
+        first_half, second_half = tensor.chunk(2, dim=0)
+        unpacked[prefix + "w1.weight"] = first_half
+        unpacked[prefix + "w3.weight"] = second_half
+    return unpacked
+
+
+def max_context_length(model_path, instruct=False):
+    """
+    Return the maximum context length for a checkpoint.
+
+    Non-instruct checkpoints get 256 * 1024 without touching the disk. For instruct
+    checkpoints, `params.json` in `model_path` is read: 10485760 when it declares 16 experts,
+    1048576 otherwise.
+    """
+    if instruct:
+        params_file = os.path.join(model_path, "params.json")
+        with open(params_file, "r") as f:
+            params = json.load(f)
+        # Some params files nest everything under a top-level "model" entry.
+        params = params.get("model", params)
+        if params["moe_args"]["num_experts"] == 16:
+            return 10485760
+        return 1048576
+    return 256 * 1024
+
+
+def write_model(
+ model_path,
+ input_base_path,
+ num_shards,
+ convert_checkpoints,
+ safe_serialization=True,
+ instruct=False,
+):
+ """
+ Build and save the Hugging Face config for an original Llama4 checkpoint and, when requested,
+ remap and merge the original weight shards into a `Llama4ForConditionalGeneration` model.
+
+ Args:
+ model_path: Output directory (created if missing). The config, the converted model and, for
+ instruct checkpoints, the generation config are saved there.
+ input_base_path: Directory containing the original `params.json` and `consolidated*.pth` files.
+ num_shards: Number of original checkpoint shards to load and merge.
+ convert_checkpoints: Gates the loading/conversion of the original weights.
+ safe_serialization: Forwarded to `model.save_pretrained`.
+ instruct: Selects the instruct EOS token ids and context length, and enables saving a
+ generation config.
+ """
+ os.makedirs(model_path, exist_ok=True)
+
+ with open(os.path.join(input_base_path, "params.json"), "r") as f:
+ params = json.load(f)
+
+ # Some params files nest everything under a top-level "model" entry.
+ params = params.get("model", params)
+ torch_dtype = "bfloat16"
+
+ # ------------------------------------------------------------
+ # Text model params and config
+ # ------------------------------------------------------------
+
+ # params from config
+ vocab_size = 202048 # params["vocab_size"] # seems like the lm head is 25256 so padded instead of 202048
+ num_layers = params["n_layers"]
+ dim = params["dim"]
+ num_heads = params["n_heads"]
+ rms_norm_eps = params["norm_eps"]
+ rope_theta = params["rope_theta"]
+ no_rope_layer_interval = params["nope_layer_interval"]
+ attention_chunk_size = params["attention_chunk_size"]
+
+ config_kwargs = {}
+ if params["use_scaled_rope"]:
+ # some constants from original code
+ rope_scaling = {
+ "rope_type": "llama3",
+ "factor": 8.0,
+ "low_freq_factor": 1.0,
+ "high_freq_factor": 4.0,
+ "original_max_position_embeddings": 8192,
+ }
+ config_kwargs.update({"rope_scaling": rope_scaling})
+
+ # compute additional params for weight conversion
+ num_heads_per_shard = num_heads // num_shards
+ dim_per_head = dim // num_heads
+ # intermediate_size = compute_intermediate_size(dim, multiple_of=params["multiple_of"])
+
+ num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
+
+ num_experts = params["moe_args"]["num_experts"]
+ interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1)
+
+ # NOTE(review): this value was already read from params above; the second read is redundant.
+ no_rope_layer_interval = params["nope_layer_interval"]
+
+ bos_token_id = 200000
+ eos_token_id = [200001, 200007, 200008] if instruct else 200001
+ pad_token_id = 200018
+
+ # NOTE(review): intermediate_size / intermediate_size_mlp are hard-coded rather than derived from
+ # params, and `no_rope_layers` receives the same value as `no_rope_layer_interval` — confirm both
+ # are intended.
+ text_config = Llama4TextConfig(
+ num_attention_heads=num_heads,
+ vocab_size=vocab_size,
+ hidden_size=dim,
+ rms_norm_eps=rms_norm_eps,
+ rope_theta=rope_theta,
+ num_hidden_layers=num_layers,
+ intermediate_size=8192,
+ intermediate_size_mlp=16384,
+ max_position_embeddings=max_context_length(input_base_path, instruct),
+ num_local_experts=num_experts,
+ interleave_moe_layer_step=interleave_moe_layer_step,
+ use_qk_norm=params["use_qk_norm"],
+ no_rope_layer_interval=no_rope_layer_interval,
+ attention_chunk_size=attention_chunk_size,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ tie_word_embeddings=False, # Constant set to False
+ torch_dtype=torch_dtype,
+ for_llm_compressor=_OFFLINE_QUANT_COMPATIBLE,
+ no_rope_layers=no_rope_layer_interval,
+ **config_kwargs,
+ )
+ # default vision config from params
+
+ vision_params = params["vision_args"]
+ vision_dim = vision_params["dim"]
+ vision_num_layers = vision_params["n_layers"]
+ image_size = vision_params["image_size"]["height"] # siglip config is outdated
+ vision_num_heads = vision_params["n_heads"]
+
+ vision_output_dim = vision_params["output_dim"]
+
+ vision_config = Llama4VisionConfig(
+ hidden_act="gelu",
+ num_hidden_layers=vision_num_layers,
+ image_size=image_size,
+ num_attention_heads=vision_num_heads,
+ hidden_size=vision_dim,
+ vision_output_dim=vision_output_dim,
+ )
+
+ config = Llama4Config(text_config=text_config, vision_config=vision_config)
+ config.save_pretrained(model_path)
+
+ print("Model config saved successfully...")
+
+ # ------------------------------------------------------------
+ # Convert weights
+ # ------------------------------------------------------------
+
+ if convert_checkpoints:
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}...")
+ if num_shards == 1:
+ # A single-shard checkpoint may be named either consolidated.00.pth or consolidated.pth.
+ if os.path.exists(os.path.join(input_base_path, "consolidated.00.pth")):
+ path = os.path.join(input_base_path, "consolidated.00.pth")
+ else:
+ path = os.path.join(input_base_path, "consolidated.pth")
+ loaded = [safe_load(path)]
+ else:
+ loaded = [
+ safe_load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"))
+ for i in tqdm(range(num_shards), desc="Loading shards", unit="shard")
+ ]
+ loaded = [preprocess_keys(d) for d in loaded]
+
+ # NOTE(review): repeated_keys / sharded_keys are filled here but never read again in this
+ # function. With a single shard `loaded[1]` does not exist, so every key ends up in the
+ # `except` branch and is only printed.
+ all_keys_raw = list(loaded[0].keys())
+ repeated_keys = []
+ sharded_keys = []
+ for _key in all_keys_raw:
+ try:
+ if (loaded[0][_key] == loaded[1][_key]).all():
+ repeated_keys.append(_key)
+ else:
+ sharded_keys.append(_key)
+ except Exception as e:
+ print(f"Encountered exception {e} for {_key}")
+ print("Initializing an empty model")
+ # Instantiated under the "meta" device; the converted tensors are attached later through
+ # `load_state_dict(..., assign=True)`.
+ with torch.device("meta"):
+ model = Llama4ForConditionalGeneration(config)
+
+ print("Converting model...")
+ all_keys = list(loaded[0].keys())
+ # Maps each original key to its renamed counterpart (empty string for dropped keys).
+ new_keys = convert_old_keys_to_new_keys(all_keys)
+ state_dict = {}
+ replicated_params = [] # To keep track of replicated weights.
+ for key in tqdm(all_keys, desc="Renaming and processing all keys", unit="key"):
+ new_key = new_keys[key]
+ print(key, new_key)
+ # Sharded parameters are collected from every shard; replicated ones only from the first.
+ if not is_param_same_across_shards(new_key):
+ current_parameter = [chunk.pop(key) for chunk in loaded if not isinstance(chunk[key], io.BytesIO)]
+ else:
+ print(f"{key} (now {new_key}) is the same across all shards.")
+ replicated_params.append((key, new_key))
+ current_parameter = [loaded[0].pop(key)] if not isinstance(loaded[0][key], io.BytesIO) else []
+
+ # NOTE(review): `new_keys` is keyed by the original key, so `pop(new_key)` looks like it
+ # would raise KeyError if this branch were ever taken — confirm.
+ if "running_gate_stats_3E" in key:
+ new_keys.pop(new_key)
+ continue
+
+ concat_dim = get_concat_dim(new_key)
+
+ # Post-process the current_parameter.
+ if "qkv_proj" in new_key:
+ # Each shard stacks its slice of Q, K and V along dim 0; split them apart, regroup per
+ # head across shards, and store them as separate q/k/v projections.
+ queries = []
+ keys = []
+ values = []
+ for param in current_parameter:
+ query, key_, value = param.split(
+ [
+ num_heads * dim_per_head // num_shards,
+ num_key_value_heads * dim_per_head // num_shards,
+ num_key_value_heads * dim_per_head // num_shards,
+ ]
+ )
+ queries.append(query.reshape(num_heads_per_shard, -1, dim))
+ keys.append(key_.reshape(num_key_value_heads // num_shards, -1, dim))
+ values.append(value.reshape(num_key_value_heads // num_shards, -1, dim))
+
+ queries = torch.cat(queries, dim=0).reshape(dim, dim)
+ keys = torch.cat(keys, dim=0).reshape(num_key_value_heads * dim_per_head, dim)
+ values = torch.cat(values, dim=0).reshape(num_key_value_heads * dim_per_head, dim)
+ # queries = permute_for_rope(queries, num_heads, dim, dim)
+ # keys = permute_for_rope(keys, num_key_value_heads, num_key_value_heads*dim_per_head, dim)
+
+ q = new_key.replace("qkv", "q")
+ tqdm.write(f"Processing: {key.ljust(50)} ->\t {q}, {queries.shape}")
+ state_dict[q] = queries
+
+ k = new_key.replace("qkv", "k")
+ tqdm.write(f"Processing: {key.ljust(50)} ->\t {k}, {keys.shape}")
+ state_dict[k] = keys
+
+ v = new_key.replace("qkv", "v")
+ tqdm.write(f"Processing: {key.ljust(50)} ->\t {v}, {values.shape}")
+ state_dict[v] = values
+ elif _OFFLINE_QUANT_COMPATIBLE and "feed_forward.experts." in new_key:
+ # for experts, we need to split expert for offline quantization purpose and don't need to fuse
+ expert_lists = []
+ for k in current_parameter:
+ expert_lists.append(
+ list(k.reshape(num_experts, -1, k.shape[-1]).unbind(0))
+ ) # [#expert * IN, OUT] -> #experts * [IN, OUT]
+ for i in range(num_experts):
+ expert = torch.cat([expert_list[i] for expert_list in expert_lists], dim=concat_dim)
+ expert_key = new_key.replace("experts.", f"experts.{i}.")
+ state_dict[expert_key] = expert.transpose(0, 1).contiguous() # [OUT, IN]
+ tqdm.write(f"Processing: {key.ljust(50)} ->\t {expert_key}, {state_dict[expert_key].shape}")
+ elif re.search(r"(gate|up)_proj", new_key):
+ # NOTE(review): `path` is assigned but not used in this branch.
+ path = new_key.split(".")
+ gate_key = re.sub(r"(gate|up)_proj", lambda m: "gate_proj", new_key)
+ up_key = re.sub(r"(gate|up)_proj", lambda m: "up_proj", new_key)
+ if gate_key == new_key:
+ state_dict[new_key] = torch.cat(current_parameter, dim=concat_dim)
+ elif new_key == up_key:
+ if "experts" not in new_key:
+ state_dict[new_key] = torch.cat(current_parameter, dim=concat_dim)
+ else:
+ # Expert up_proj: fetch the gate_proj stored earlier and fuse both into gate_up_proj.
+ # NOTE(review): the literals 8 and 1024 are hard-coded — presumably the shard count
+ # and a per-shard slice width; confirm before converting a different layout.
+ gate_proj = state_dict.pop(gate_key)
+ gate_proj = [
+ gate_proj.reshape(num_experts, -1, 8, 1024)[:, :, k, :].reshape(num_experts, -1, 1024)
+ for k in range(8)
+ ]
+ gate_proj = torch.cat(gate_proj, dim=-1)
+
+ up_proj = [
+ k.reshape(num_experts, -1, 8, 1024).reshape(num_experts, -1, 1024)
+ for k in current_parameter
+ ]
+ up_proj = torch.cat(up_proj, dim=-1)
+
+ gate_up_proj = torch.cat((gate_proj, up_proj), dim=-1)
+ new_key = new_key.replace("up_proj", "gate_up_proj")
+ state_dict[new_key] = gate_up_proj.contiguous()
+
+ tqdm.write(f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}")
+ elif "down_proj" in new_key:
+ current_parameter = torch.cat(current_parameter, dim=concat_dim)
+ if "experts" in new_key:
+ # NOTE(review): 8 and 5120 are hard-coded here as well — confirm they match the
+ # checkpoint being converted.
+ p = []
+ for i in range(8):
+ p += [current_parameter.reshape(8, -1, 5120)[i, :, :].view(num_experts, -1, 5120)]
+ current_parameter = torch.cat(p, dim=1)
+ state_dict[new_key] = current_parameter.contiguous()
+ tqdm.write(f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}")
+ elif "router" in new_key:
+ current_parameter = torch.cat(current_parameter, dim=concat_dim)
+ state_dict[new_key] = current_parameter.transpose(0, 1)
+ elif "lm_head" in new_key:
+ current_parameter = torch.cat(current_parameter, dim=concat_dim).clone()
+ # TODO we need to do better than mean, works for now
+ # if (vocab_size - current_parameter.shape[0]) > 0:
+ # mean_embedding = torch.mean(current_parameter, dim=0)[:, None].repeat(vocab_size-current_parameter.shape[0],1)
+ # print(mean_embedding.shape)
+ # current_parameter = torch.cat((current_parameter, mean_embedding), dim=0)
+ state_dict[new_key] = current_parameter
+ tqdm.write(
+ f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}, concat dim = {concat_dim}"
+ )
+ elif new_key == "vision_model.patch_embedding.linear.weight":
+ current_parameter = torch.cat(current_parameter, dim=concat_dim).clone()
+ # We don't reshape the patch embedding as we're using unfolded convolution as well
+ state_dict[new_key] = current_parameter # .reshape(-1, 3, vision_patch_size, vision_patch_size)
+ # generic concat for weights/select one for biases
+ elif isinstance(current_parameter, list) and len(current_parameter) > 0:
+ if not is_param_same_across_shards(new_key):
+ current_parameter = torch.cat(current_parameter, dim=concat_dim)
+ state_dict[new_key] = current_parameter
+ tqdm.write(
+ f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}, concat dim = {concat_dim}"
+ )
+ elif is_param_same_across_shards(new_key):
+ state_dict[new_key] = current_parameter[0]
+ tqdm.write(
+ f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}, concat dim = {concat_dim}"
+ )
+
+ elif new_key == "":
+ # skip empty keys
+ continue
+ else:
+ # just load the parameter
+ state_dict[new_key] = current_parameter
+ tqdm.write(
+ f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}, concat dim = {concat_dim}"
+ )
+ del loaded
+ gc.collect()
+
+ print("Loading the checkpoint in a Llama4 model.")
+ # Dropped keys were renamed to the empty string; remove that leftover entry before loading.
+ # NOTE(review): `pop("")` raises KeyError when no key was renamed to the empty string.
+ state_dict.pop("")
+ model.load_state_dict(state_dict, strict=True, assign=True)
+ print("Model reloaded successfully.")
+ print("Saving the model.")
+ model.save_pretrained(model_path, safe_serialization=safe_serialization)
+ del state_dict, model
+
+ # Safety check: reload the converted model
+ gc.collect()
+ print("Reloading the model to check if it's saved correctly.")
+ with torch.no_grad():
+ # TODO test if we can do `tp_plan="auto"``
+ model = Llama4ForConditionalGeneration.from_pretrained(
+ model_path, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager"
+ )
+
+ model.generation_config.top_p = 0.9
+ model.generation_config.temperature = 0.6
+ print("Model reloaded successfully.")
+
+ from transformers import AutoTokenizer
+
+ # Short generation as a smoke test; the tokenizer is read from the same output directory.
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ inputs = tokenizer(["Roses are red,"], return_tensors="pt").to(model.device)
+ out = model.generate(**inputs, max_new_tokens=4)
+ print(tokenizer.batch_decode(out))
+ # generation config
+ if instruct:
+ print("Saving generation config...")
+ generation_config = GenerationConfig(
+ do_sample=True,
+ temperature=0.6,
+ top_p=0.9,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ )
+ generation_config.save_pretrained(model_path)
+
+
+BOS_ADDED_TOKEN = AddedToken(
+ "<|begin_of_text|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True
+)
+EOS_ADDED_TOKEN = AddedToken(
+ "<|end_of_text|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True
+)
+EOT_ADDED_TOKEN = AddedToken("<|eot|>", single_word=False, lstrip=False, rstrip=False, normalized=False, special=True)
+
+
+def get_reserved_special_tokens(name, count, start_index=0):
+    """
+    Return `count` reserved special token strings for the `name` family, numbered
+    consecutively starting at `start_index`.
+    """
+    tokens = []
+    for offset in range(count):
+        i = start_index + offset
+        tokens.append(f"<|{name}_reserved_special_token_{i}|>")
+    return tokens
+
+
+# 200005, ..., 200079
+LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
+ "<|header_start|>",
+ "<|header_end|>",
+ "<|eom|>",
+ "<|eot|>",
+ "<|step|>",
+ "<|text_post_train_reserved_special_token_0|>",
+ "<|text_post_train_reserved_special_token_1|>",
+ "<|text_post_train_reserved_special_token_2|>",
+ "<|text_post_train_reserved_special_token_3|>",
+ "<|text_post_train_reserved_special_token_4|>",
+ "<|text_post_train_reserved_special_token_5|>",
+ "<|python_start|>",
+ "<|python_end|>",
+ "<|finetune_right_pad|>",
+] + get_reserved_special_tokens(
+ "text_post_train", 61, 6
+) # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>
+
+# 200080, ..., 201133
+LLAMA4_VISION_SPECIAL_TOKENS = [
+ "<|image_start|>",
+ "<|image_end|>",
+ "<|vision_reserved_special_token_0|>",
+ "<|vision_reserved_special_token_1|>",
+ "<|tile_x_separator|>",
+ "<|tile_y_separator|>",
+ "<|vision_reserved_special_token_2|>",
+ "<|vision_reserved_special_token_3|>",
+ "<|vision_reserved_special_token_4|>",
+ "<|vision_reserved_special_token_5|>",
+ "<|image|>",
+ "<|vision_reserved_special_token_6|>",
+ "<|patch|>",
+] + get_reserved_special_tokens(
+ "vision", 1041, 7
+) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
+
+LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
+
+BASIC_SPECIAL_TOKENS = [
+ "<|begin_of_text|>",
+ "<|end_of_text|>",
+ "<|fim_prefix|>",
+ "<|fim_middle|>",
+ "<|fim_suffix|>",
+]
+
+
+class Llama4Converter(TikTokenConverter):
+    """
+    Converts the original Llama4 tokenizer file into a `PreTrainedTokenizerFast`, exposed as
+    `self.converted_tokenizer`.
+    """
+
+    def __init__(
+        self,
+        vocab_file,
+        special_tokens: List[str],
+        pattern: str,
+        model_max_length: int = 0,
+        chat_template: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(vocab_file, pattern=pattern)
+        self.additional_special_tokens = special_tokens
+        backend_tokenizer = self.converted()
+
+        # A chat template is only supplied for instruct checkpoints.
+        instruct = chat_template is not None
+        if instruct:
+            kwargs["chat_template"] = chat_template
+
+        self.converted_tokenizer = PreTrainedTokenizerFast(
+            tokenizer_object=backend_tokenizer,
+            model_input_names=["input_ids", "attention_mask"],
+            model_max_length=model_max_length,
+            **kwargs,
+        )
+        self.update_post_processor(self.converted_tokenizer)
+
+        # finer special_tokens_map.json
+        self.converted_tokenizer._bos_token = BOS_ADDED_TOKEN
+        if instruct:
+            self.converted_tokenizer._eos_token = EOT_ADDED_TOKEN
+        else:
+            self.converted_tokenizer._eos_token = EOS_ADDED_TOKEN
+
+    # We can't do this while building the tokenizer because we have no easy access to the bos token id
+    def update_post_processor(self, tokenizer):
+        """Install a post-processor on `tokenizer` that prepends `<|begin_of_text|>` to each sequence."""
+        bos_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
+        bos_template = processors.TemplateProcessing(
+            single="<|begin_of_text|> $A",
+            pair="<|begin_of_text|>:0 $A:0 <|begin_of_text|>:1 $B:1",
+            special_tokens=[
+                ("<|begin_of_text|>", bos_id),
+            ],
+        )
+        byte_level = processors.ByteLevel(trim_offsets=False)
+        tokenizer._tokenizer.post_processor = processors.Sequence([byte_level, bos_template])
+
+
+O200K_PATTERN = r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+""" # noqa: E501
+
+
+def write_tokenizer(args):
+ tokenizer_path = os.path.join(args.input_dir, "tokenizer.model")
+ chat_template = "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %} \n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- else %}\n {#- FIXME: The processor requires an array, always. #}\n {%- set system_message = messages[0]['content'][0]['text']|trim %}\n {%- endif %}\n {%- set messages = messages[1:] %}\n {%- set user_supplied_system_message = true %}\n{%- else %}\n {%- set system_message = \"\" %}\n {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- System message if the user supplied one #}\n{%- if user_supplied_system_message %}\n {{- \"<|header_start|>system<|header_end|>\n\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\n\" }}\n {%- endif %}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|header_start|>user<|header_end|>\n\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\n\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' -}}\n {{- '<|python_start|>' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|python_end|>' }}\n {%- for tool_call in message.tool_calls %}\n {{- '{\"name\": \"' + tool_call.function.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.function.arguments | tojson }}\n {{- \"}\" }}\n {%- endfor %}\n {{- \"<|eot|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|header_start|>ipython<|header_end|>\n\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' }}\n{%- endif %}\n"
+
+ # Full special-token list, in id order: the basic tokens first, then the Llama4-specific ones.
+ special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
+ converter = Llama4Converter(
+ vocab_file=tokenizer_path,
+ pattern=O200K_PATTERN,
+ special_tokens=special_tokens,
+ chat_template=chat_template if args.instruct else None,
+ bos_token="<|begin_of_text|>",
+ eos_token="<|end_of_text|>" if not args.instruct else "<|eot|>",
+ # The pad token must be one of the registered special tokens: the list defines
+ # "<|finetune_right_pad|>" (there is no "<|finetune_right_pad_id|>" entry).
+ pad_token="<|finetune_right_pad|>",
+ model_max_length=max_context_length(args.input_dir, args.instruct),
+ )
+ tokenizer = converter.converted_tokenizer
+
+ # Bundle the tokenizer with the image processor and save both as a single processor.
+ image_processor = Llama4ImageProcessorFast()
+ processor = Llama4Processor(
+ image_processor=image_processor,
+ tokenizer=tokenizer,
+ chat_template=tokenizer.chat_template,
+ )
+ processor.save_pretrained(args.output_dir)
+ del processor
+
+
+if __name__ == "__main__":
+
+    def _str_to_bool(value):
+        """Parse a command-line boolean (`type=bool` would turn any non-empty string, even "False", into True)."""
+        normalized = str(value).strip().lower()
+        if normalized in ("1", "true", "t", "yes", "y", "on"):
+            return True
+        if normalized in ("0", "false", "f", "no", "n", "off", ""):
+            return False
+        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input_dir",
+        type=str,
+        default="/fsx/arthur/Llama-4-17B-Omni-Instruct-Original",
+        help="Location of the local folder copied from the Hub.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        default="llama4_hf_vision",
+        type=str,
+        help="Location to write HF model and tokenizer",
+    )
+    parser.add_argument(
+        "--safe_serialization",
+        default=True,
+        type=_str_to_bool,
+        help="Whether or not to save using `safetensors`.",
+    )
+    # `typing.List[str]` cannot be used as an argparse converter; accept zero or more strings instead.
+    parser.add_argument(
+        "--special_tokens",
+        default=None,
+        nargs="*",
+        type=str,
+        help="The list of special tokens that should be added to the model.",
+    )
+    parser.add_argument(
+        "--num_shards",
+        default=8,
+        type=int,
+        help="The number of individual shards used for the model. Does not have to be the same as the number of consolidated_xx.pth",
+    )
+    parser.add_argument(
+        "--instruct",
+        action="store_true",
+        help="Whether the model is an instruct model",
+    )
+    parser.add_argument(
+        "--convert_checkpoints",
+        action="store_true",
+        help="Whether to convert the original weights (or skip if previously converted)",
+    )
+
+    args = parser.parse_args()
+    # The tokenizer/processor is written first, then the config and (optionally) the weights.
+    write_tokenizer(args)
+
+    write_model(
+        model_path=args.output_dir,
+        input_base_path=args.input_dir,
+        safe_serialization=args.safe_serialization,
+        num_shards=args.num_shards,
+        instruct=args.instruct,
+        convert_checkpoints=args.convert_checkpoints,
+    )
diff --git a/src/transformers/models/llama4/image_processing_llama4_fast.py b/src/transformers/models/llama4/image_processing_llama4_fast.py
new file mode 100644
index 0000000000..6935ba798f
--- /dev/null
+++ b/src/transformers/models/llama4/image_processing_llama4_fast.py
@@ -0,0 +1,480 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for Llama4."""
+
+import math
+from collections import defaultdict
+from functools import lru_cache
+from typing import List, Optional, Set, Tuple, Union
+
+from ...image_processing_utils import BatchFeature
+from ...image_processing_utils_fast import (
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+ BaseImageProcessorFast,
+ DefaultFastImageProcessorKwargs,
+ group_images_by_shape,
+ reorder_images,
+)
+from ...image_utils import (
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ add_start_docstrings,
+ is_torch_available,
+ is_torchvision_available,
+ is_torchvision_v2_available,
+)
+
+
+if is_torch_available():
+ import torch
+
+if is_torchvision_available():
+ if is_torchvision_v2_available():
+ from torchvision.transforms.v2 import functional as F
+ else:
+ from torchvision.transforms import functional as F
+
+
def get_factors(dividend: int) -> Set[int]:
    """
    Return every factor of `dividend`, i.e. each divisor that leaves no remainder.
    For example, `dividend=12` gives {1, 2, 3, 4, 6, 12}.

    Args:
        dividend (int): The number to find factors for.

    Returns:
        set: A set containing all factors of the number.
    """
    # Only candidates up to the square root need testing: each hit `d` also yields its partner `dividend // d`.
    upper_bound = int(dividend**0.5) + 1
    small_divisors = [candidate for candidate in range(1, upper_bound) if dividend % candidate == 0]
    return {factor for divisor in small_divisors for factor in (divisor, dividend // divisor)}
+
+
def get_max_res_without_distortion(
    image_size: Tuple[int, int],
    target_size: Tuple[int, int],
) -> Tuple[int, int]:
    """
    Compute the largest size an image can take inside `target_size` while keeping its aspect ratio.

    The side with the smaller scale factor is the limiting one: it is set to the target value and the
    other side is scaled by the same factor, rounded down and clamped to the target.

    Args:
        image_size (Tuple[int, int]): The original resolution of the image (height, width).
        target_size (Tuple[int, int]): The resolution to fit the image into (height, width).
    Returns:
        Tuple[int, int]: The dimensions (height, width) to which the image should be resized.
    Example:
        >>> get_max_res_without_distortion([200, 300], target_size=[450, 200])
        (133, 200)
        >>> get_max_res_without_distortion([800, 600], target_size=[450, 1300])
        (450, 337)
    """
    src_height, src_width = image_size
    dst_height, dst_width = target_size

    width_ratio = dst_width / src_width
    height_ratio = dst_height / src_height

    if width_ratio < height_ratio:
        # Width is the limiting side: fill it and shrink the height accordingly.
        fitted_height = min(math.floor(src_height * width_ratio), dst_height)
        return fitted_height, dst_width

    # Height is the limiting side (ties land here as well).
    fitted_width = min(math.floor(src_width * height_ratio), dst_width)
    return dst_height, fitted_width
+
+
class Llama4ImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    Keyword arguments accepted by the Llama4 fast image processor on top of the default fast-processor ones.

    max_patches: maximum number of tiles an image may be split into.
    resize_to_max_canvas: whether to pick the canvas allowing the largest resize.
    """

    max_patches: Optional[int]
    resize_to_max_canvas: Optional[bool]
+
+
def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_width: int) -> "torch.Tensor":
    """
    Cut a batch of images into a grid of equally sized tiles.

    Args:
        images (`torch.Tensor`): Batch of shape (batch_size, num_channels, height, width). `height` and
            `width` must be divisible by `num_tiles_height` and `num_tiles_width` respectively, and the
            tensor must be viewable without a copy (`Tensor.view` is used).
        num_tiles_height (int): Number of tiles along the height axis.
        num_tiles_width (int): Number of tiles along the width axis.

    Returns:
        `torch.Tensor`: Tiles of shape
        (batch_size, num_tiles_height * num_tiles_width, num_channels, tile_height, tile_width),
        ordered row by row (tile index = row * num_tiles_width + column).
    """
    batch_size, num_channels, height, width = images.size()
    tile_height = height // num_tiles_height
    tile_width = width // num_tiles_width

    # Expose the tile grid as separate axes: (batch, channels, rows, tile_height, cols, tile_width)
    images = images.view(batch_size, num_channels, num_tiles_height, tile_height, num_tiles_width, tile_width)
    # Bring the grid axes next to the batch axis: (batch, rows, cols, channels, tile_height, tile_width)
    tiles = images.permute(0, 2, 4, 1, 3, 5).contiguous()
    # Merge the two grid axes into a single tile axis
    return tiles.view(batch_size, num_tiles_height * num_tiles_width, num_channels, tile_height, tile_width)
+
+
@lru_cache(maxsize=1)
def find_supported_resolutions(max_num_chunks: int, patch_size: "SizeDict") -> List[Tuple[int, int]]:
    """
    Computes all of the allowed resolutions for a fixed number of chunks
    and patch_size. Useful for when dividing an image into chunks.

    The result is memoized with `lru_cache`, so `patch_size` must be hashable and the returned list is
    shared between calls with the same arguments: treat it as read-only.

    Args:
        max_num_chunks (int): Maximum number of chunks for processing.
        patch_size (SizeDict): Size of one patch; only `height` and `width` are read and they must be equal.

    Returns:
        List[Tuple[int, int]]: Possible resolutions as (height, width) tuples.

    Raises:
        ValueError: If `patch_size` is not square.

    Example:
        >>> max_num_chunks = 4
        >>> patch_size = SizeDict(height=224, width=224)
        >>> find_supported_resolutions(max_num_chunks, patch_size)
        [(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
         (672, 224), (224, 448), (448, 224)]

        Given max_num_chunks=4, patch_size=224, it will create a dictionary:
        {
        0.25: [(1, 4)],
        1.0: [(2, 2), (1, 1)],
        4.0: [(4, 1)],
        0.33: [(1, 3)],
        3.0: [(3, 1)],
        0.5: [(1, 2)],
        2.0: [(2, 1)]
        }

        and return the resolutions multiplied by the patch_size:
        [(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
    """
    if patch_size.height != patch_size.width:
        raise ValueError("`size` must be square.")
    patch_side = patch_size.height

    # Group every (rows, cols) tiling by aspect ratio. Dict insertion order defines the output order.
    tilings_by_ratio = defaultdict(list)
    for num_chunks in range(max_num_chunks, 0, -1):
        for num_rows in sorted(get_factors(num_chunks)):
            num_cols = num_chunks // num_rows
            tilings_by_ratio[num_rows / num_cols].append((num_rows, num_cols))

    # Convert tile counts to pixel resolutions
    return [
        (num_rows * patch_side, num_cols * patch_side)
        for tilings in tilings_by_ratio.values()
        for num_rows, num_cols in tilings
    ]
+
+
def pad_to_best_fit(
    images: "torch.Tensor",
    target_size: Tuple[int, int],
    background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
    """
    Pads images on the right and bottom so that they reach the target size.

    Args:
        images (`torch.Tensor`):
            The images to pad, either (batch_size, num_channels, height, width) or (num_channels, height, width).
        target_size (`Tuple[int, int]`):
            The (height, width) to reach after padding.
        background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
            The color to use for the padding. Can be an integer for single channel or a
            tuple of integers for multi-channel images. If passed as an integer
            in multi-channel mode, it will default to `0` in subsequent channels.
    Returns:
        `torch.Tensor`: The padded images.

    Raises:
        ValueError: If `background_color` is a sequence whose length differs from the number of channels.
    """
    num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
    if isinstance(background_color, int):
        background_color = [background_color] + [0] * (num_channels - 1)
    elif len(background_color) != num_channels:
        raise ValueError(
            f"background_color must have exactly {num_channels} elements to match the number of channels"
        )

    height, width = images.shape[-2:]
    target_height, target_width = target_size
    pad_right = target_width - width
    pad_bottom = target_height - height
    # torchvision padding order is [left, top, right, bottom]: the image stays anchored at the top-left corner.
    return F.pad(images, padding=[0, 0, pad_right, pad_bottom], fill=background_color)
+
+
def get_best_fit(
    image_size: Tuple[int, int],
    possible_resolutions: "torch.Tensor",
    resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
    """
    Determines the best canvas possible from a list of possible resolutions to, without distortion,
    resize an image to.

    For each possible resolution, calculates the scaling factors for
    width and height, and selects the smallest one, which is the limiting side.
    E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
    therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.

    If the image fits at least one canvas without downscaling (some scaling factor is >= 1),
    then picks the smallest such factor, or the largest one if resize_to_max_canvas is True.

    Otherwise picks the largest scaling factor < 1, i.e. reduces downscaling as much as possible.

    If there are multiple resolutions with the same selected scale, we pick the one with the lowest area,
    to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
    has more padding.

    Args:
        image_size (Tuple[int, int]): A tuple containing the height and width of the image.
        possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
            row represents a possible resolution (height, width).
        resize_to_max_canvas (bool): If True, will return the largest upscaling resolution.

    Returns:
        Tuple[int, int]: The best resolution (height, width) for the given image.

    Example:
        >>> image_size = (200, 300)
        >>> possible_resolutions = torch.tensor([[224, 672],
        ...                                      [672, 224],
        ...                                      [224, 448],
        ...                                      [448, 224],
        ...                                      [224, 224]])
        >>> get_best_fit(image_size, possible_resolutions)
        (224, 448)

        We have:
            scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
            scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
            scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
        Two of the scales are >= 1:
            upscaling_options = tensor([1.1200, 1.1200])
            selected_scale = tensor(1.1200)
        So we pick the resolution with the smallest area:
            areas = tensor([150528, 100352]) # [224, 672], [224, 448]
            optimal_canvas = tensor([224, 448])
    """
    original_height, original_width = image_size

    # scaling factors needed to reach each candidate canvas along each axis
    scale_w = possible_resolutions[:, 1] / original_width
    scale_h = possible_resolutions[:, 0] / original_height

    # the smaller factor is the limiting side (-> no distortion)
    scales = torch.where(scale_h > scale_w, scale_w, scale_h)

    upscaling_options = scales[scales >= 1]
    if len(upscaling_options) > 0:
        selected_scale = torch.max(upscaling_options) if resize_to_max_canvas else torch.min(upscaling_options)
    else:
        # every canvas requires downscaling: keep as much resolution as possible
        selected_scale = torch.max(scales[scales < 1])

    # several canvases can share the selected factor,
    # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
    chosen_canvas = possible_resolutions[scales == selected_scale]

    # the smallest area among them needs the least padding (argmin also covers the single-candidate case)
    areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
    optimal_canvas = chosen_canvas[torch.argmin(areas)]

    return tuple(optimal_canvas.tolist())
+
+
@add_start_docstrings(
    "Constructs a fast Llama4 image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
        max_patches (`int`, *optional*, defaults to 16):
            The maximum number of patches to be extracted from the image.
            Can be overridden by the `max_patches` parameter in the `preprocess` method.
        resize_to_max_canvas (`bool`, *optional*, defaults to False):
            Whether to resize the image to the maximum canvas size.
            If True, picks the canvas the allows the largest resizing without distortion.
            If False, downsample as little as possible, including no resizing at all,
            but never upsample, unless the image is smaller than the patch size.
    """,
)
class Llama4ImageProcessorFast(BaseImageProcessorFast):
    # Class-level defaults; each can be overridden through the constructor or `preprocess` kwargs.
    resample = PILImageResampling.BILINEAR
    image_mean = [0.5, 0.5, 0.5]
    image_std = [0.5, 0.5, 0.5]
    size = {"height": 336, "width": 336}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    max_patches = 16
    resize_to_max_canvas = False
    valid_kwargs = Llama4ImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[Llama4ImageProcessorKwargs]):
        super().__init__(**kwargs)

    def rescale_and_normalize(
        self,
        images: "torch.Tensor",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Union[float, List[float]],
        image_std: Union[float, List[float]],
    ) -> "torch.Tensor":
        """
        Rescale and normalize images.

        When both steps are requested, the images are first cast to `torch.bfloat16`, as in the original
        implementation. When only one step is requested the input dtype is left untouched.

        Args:
            images (`torch.Tensor`): Images to process.
            do_rescale (`bool`): Whether to multiply the images by `rescale_factor`.
            rescale_factor (`float`): Multiplicative rescaling factor.
            do_normalize (`bool`): Whether to normalize with `image_mean` and `image_std`.
            image_mean (`float` or `List[float]`): Mean used for normalization.
            image_std (`float` or `List[float]`): Standard deviation used for normalization.

        Returns:
            `torch.Tensor`: The processed images.
        """
        if do_rescale and do_normalize:
            images = images.to(dtype=torch.bfloat16) * rescale_factor
            images = self.normalize(images, image_mean, image_std)
        elif do_rescale:
            images = images * rescale_factor
        elif do_normalize:
            images = self.normalize(images, image_mean, image_std)

        return images

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
            max_patches (`int`, *optional*, defaults to 16):
                The maximum number of patches to be extracted from the image.
                Can be overridden by the `max_patches` parameter in the `preprocess` method.
            resize_to_max_canvas (`bool`, *optional*, defaults to False):
                Whether to resize the image to the maximum canvas size.
                If True, picks the canvas the allows the largest resizing without distortion.
                If False, downsample as little as possible, including no resizing at all,
                but never upsample, unless the image is smaller than the patch size.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Llama4ImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        size: SizeDict,
        max_patches: int,
        resize_to_max_canvas: bool,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Resize, pad, normalize and tile the images.

        For every group of same-shaped images: pick the best canvas, resize without distortion, pad to the
        canvas, rescale/normalize, split into tiles of `size`, and, when there is more than one tile, append a
        "global" tile obtained by resizing the whole image to `size`.

        Returns:
            `BatchFeature` with `pixel_values` (the tiles) and `aspect_ratios` (the
            (num_tiles_height, num_tiles_width) pair of every image).
        """
        possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size)
        possible_resolutions = torch.tensor(possible_resolutions)
        # process images by batch, grouped by shape
        grouped_images, grouped_images_index = group_images_by_shape(images)
        grouped_processed_images = {}
        grouped_aspect_ratios = {}
        for shape, stacked_images in grouped_images.items():
            image_size = stacked_images.shape[-2:]
            target_size = get_best_fit(image_size, possible_resolutions, resize_to_max_canvas=resize_to_max_canvas)

            # Bounding box for the aspect-preserving resize. It must be defined on both paths: leaving it
            # unset when `resize_to_max_canvas` is True raised a NameError on the first group, or silently
            # reused the previous group's value.
            if resize_to_max_canvas:
                # the image may grow up to the whole canvas
                target_size_without_distortion = target_size
            else:
                # upscaling is limited to one tile side, and never beyond the canvas
                max_upscaling_size = size.height
                target_size_without_distortion = (
                    min(max(image_size[0], max_upscaling_size), target_size[0]),
                    min(max(image_size[1], max_upscaling_size), target_size[1]),
                )

            # resize to target_size while preserving aspect ratio
            new_size_without_distortion = get_max_res_without_distortion(image_size, target_size_without_distortion)
            new_size_without_distortion = SizeDict(
                height=max(new_size_without_distortion[0], 1), width=max(new_size_without_distortion[1], 1)
            )
            processed_images = self.resize(
                stacked_images,
                new_size_without_distortion,
                interpolation=interpolation,
            )

            # pad to target_size to be able to split into tiles
            processed_images = pad_to_best_fit(processed_images, target_size)
            processed_images = self.rescale_and_normalize(
                processed_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )

            # number of tiles along each axis (the tile is square, see `find_supported_resolutions`)
            ratio_h, ratio_w = (
                target_size[0] // size.height,
                target_size[1] // size.width,
            )
            # split into tiles
            processed_images = split_to_tiles(processed_images, ratio_h, ratio_w)
            grouped_processed_images[shape] = processed_images
            grouped_aspect_ratios[shape] = torch.tensor([[ratio_h, ratio_w]] * stacked_images.shape[0])

            # add a global tile to the processed tile if there are more than one tile
            if ratio_h * ratio_w > 1:
                global_tiles = self.resize(
                    stacked_images,
                    size,
                    interpolation=interpolation,
                )
                global_tiles = self.rescale_and_normalize(
                    global_tiles, do_rescale, rescale_factor, do_normalize, image_mean, image_std
                )
                grouped_processed_images[shape] = torch.cat([processed_images, global_tiles.unsqueeze(1)], dim=1)
        processed_images = reorder_images(grouped_processed_images, grouped_images_index)
        aspect_ratios_list = reorder_images(grouped_aspect_ratios, grouped_images_index)

        # tiles of all images are concatenated along the first axis when tensors are requested
        processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
        aspect_ratios = torch.stack(aspect_ratios_list, dim=0) if return_tensors else aspect_ratios_list
        return BatchFeature(
            data={"pixel_values": processed_images, "aspect_ratios": aspect_ratios}, tensor_type=return_tensors
        )


__all__ = ["Llama4ImageProcessorFast"]
diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py
new file mode 100644
index 0000000000..f511674dc0
--- /dev/null
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -0,0 +1,1903 @@
+# coding=utf-8
+# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from transformers.models.llama4.configuration_llama4 import Llama4VisionConfig
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import (
+ BaseModelOutput,
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ ModelOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_torch_flex_attn_available,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_llama4 import Llama4Config, Llama4TextConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "meta-ai/Llama-4-17B"
+_CONFIG_FOR_DOC = "Llama4Config"
+
+
class Llama4TextExperts(nn.Module):
    """Stacked expert feed-forward weights, applied to all experts at once with batched matmuls."""

    def __init__(self, config: Llama4Config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        # One weight slab per expert; gate and up projections are fused along the last axis.
        # `torch.empty` leaves the values uninitialized here.
        gate_up_shape = (self.num_experts, self.hidden_size, 2 * self.expert_dim)
        down_shape = (self.num_experts, self.expert_dim, self.hidden_size)
        self.gate_up_proj = nn.Parameter(torch.empty(*gate_up_shape))
        self.down_proj = nn.Parameter(torch.empty(down_shape))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        This should really not be run on a single machine, as we are reaching compute bound:
        - the inputs are expected to be "sorted" per expert already.
        - the input is viewed as (num_experts, tokens_per_expert, hidden_size) to match the stacked weights.

        Args:
            hidden_states (torch.Tensor): (num_experts * tokens_per_expert, hidden_size)
        Returns:
            torch.Tensor: (num_experts * tokens_per_expert, hidden_size)
        """
        per_expert_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
        fused = torch.bmm(per_expert_states, self.gate_up_proj)
        gate, up = fused.chunk(2, dim=-1)  # not supported for DTensors
        activated = up * self.act_fn(gate)
        projected = torch.bmm(activated, self.down_proj)
        return projected.view(-1, self.hidden_size)
+
+
+# Phi3MLP
class Llama4TextMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act(gate_proj(x)) * up_proj(x))`, all projections without bias."""

    def __init__(self, config, intermediate_size=None):
        super().__init__()

        # Fall back to the config value when no explicit width is given
        width = config.intermediate_size if intermediate_size is None else intermediate_size

        self.config = config
        self.gate_proj = nn.Linear(config.hidden_size, width, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, width, bias=False)
        self.down_proj = nn.Linear(width, config.hidden_size, bias=False)
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.activation_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
+
+
class Llama4TextL2Norm(torch.nn.Module):
    """Weight-free RMS normalization over the last dimension (`dim` is accepted but unused)."""

    def __init__(self, dim: int = None, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalize in float32, then return in the caller's dtype
        normalized = self._norm(x.float())
        return normalized.type_as(x)

    def extra_repr(self):
        return f"eps={self.eps}"
+
+
class Llama4TextRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """
        Llama4RMSNorm is equivalent to T5LayerNorm: RMS normalization over the last dimension followed by a
        learned per-feature scale initialized to ones.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalize in float32, cast back to the input dtype, then apply the learned scale
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
class Llama4TextMoe(nn.Module):
    """Mixture-of-experts feed-forward: a shared MLP plus sigmoid-gated routed experts."""

    def __init__(self, config):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_local_experts
        self.experts = Llama4TextExperts(config)
        self.router = nn.Linear(config.hidden_size, config.num_local_experts, bias=False)
        self.shared_expert = Llama4TextMLP(config)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: tensor of shape (batch, seq_len, hidden_dim).

        Returns:
            A tuple `(out, router_scores)` where `out` has shape (batch * seq_len, hidden_dim) and
            `router_scores` has shape (num_experts, batch * seq_len).
        """
        batch, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_dim)
        num_tokens = batch * seq_len

        # Per-token logits, laid out (tokens, experts)
        token_logits = self.router(hidden_states)
        top_values, top_experts = torch.topk(token_logits, self.top_k, dim=1)

        # Keep only the top-k logits and put -inf elsewhere so that the sigmoid maps every
        # non-selected (token, expert) pair to exactly 0. The result is transposed to (experts, tokens).
        router_scores = (
            torch.full_like(token_logits, float("-inf")).scatter_(1, top_experts, top_values).transpose(0, 1)
        )
        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

        # Row index of every token, repeated once per expert and broadcast over the hidden dimension.
        # Let s maybe register a buffer for this!
        token_positions = (
            torch.arange(num_tokens, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1)
        )
        router_indices = token_positions.reshape(-1, 1).expand(-1, hidden_dim)

        # Every expert receives every token, pre-scaled by its routing score
        routed_in = torch.gather(input=hidden_states, dim=0, index=router_indices)
        routed_in = routed_in * router_scores.reshape(-1, 1)
        routed_out = self.experts(routed_in)

        out = self.shared_expert(hidden_states)
        # now that we finished expert computation -> we scatter add because we gathered previously
        # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you
        # are compute bound. This scales a lot better if you do EP!
        out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, hidden_dim))
        return out, router_scores
+
+
class Llama4TextRotaryEmbedding(nn.Module):
    """Produces the complex rotary factors (`freqs_cis`) for the given position ids."""

    def __init__(self, config: Llama4TextConfig, device=None):
        super().__init__()
        # "llama3" scaling is used whenever the config carries a `rope_scaling` entry
        self.rope_type = "default" if config.rope_scaling is None else "llama3"

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1

        grew_past_cache = seq_len > self.max_seq_len_cached
        if grew_past_cache:
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        back_to_original = (
            seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len
        )
        if back_to_original:
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
            freqs_cis = torch.polar(torch.ones_like(angles), angles)  # Convert to complex representation

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        return freqs_cis * self.attention_scaling
+
+
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate queries and keys by the complex factors in `freqs_cis`.

    Consecutive pairs of the last dimension are interpreted as (real, imaginary) parts, multiplied by
    `freqs_cis` (broadcast over the third axis) in float32, and cast back to the input dtypes.

    Returns:
        The rotated `(xq, xk)` with the same shapes and dtypes as the inputs.
    """
    # A singleton axis is inserted at position 2 so the factors broadcast over that axis
    rotation = freqs_cis[:, :, None, :]

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        paired = states.float().reshape(*states.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(paired) * rotation
        return torch.view_as_real(rotated).flatten(3).type_as(states)

    return _rotate(xq), _rotate(xk)
+
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times, like torch.repeat_interleave(x, dim=1, repeats=n_rep).

    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim). With `n_rep == 1` the input is returned as is.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the head axis
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
+
+
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Plain PyTorch scaled dot-product attention.

    Args:
        module (nn.Module): The attention module; `num_key_value_groups` and `training` are read from it.
        query (torch.Tensor): Queries of shape (batch, num_attention_heads, q_len, head_dim).
        key (torch.Tensor): Keys of shape (batch, num_key_value_heads, kv_len, head_dim).
        value (torch.Tensor): Values with the same layout as `key`.
        attention_mask (Optional[torch.Tensor]): Additive mask, cropped to the key length on its last axis.
        scaling (float): Factor applied to the raw attention scores.
        dropout (float): Dropout probability on the attention weights (active only in training mode).
        **kwargs: Accepted for interface compatibility with the other attention functions; unused.

    Returns:
        `(attn_output, attn_weights)`, with `attn_output` of shape (batch, q_len, num_attention_heads, head_dim).
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    # Honour the `scaling` argument instead of recomputing 1/sqrt(head_dim), so the eager path agrees with
    # the other attention implementations that receive the same argument.
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # softmax in float32 for numerical stability, then back to the query dtype
    attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
+
+
class Llama4TextAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper, with grouped key/value heads.

    Rotary embeddings are applied on every layer except each fourth one (`(layer_idx + 1) % 4 == 0`).
    Layers without rotary embeddings can instead scale their queries with a position-dependent temperature
    (`attn_temperature_tuning`).
    """

    def __init__(self, config: Llama4TextConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attn_scale = config.attn_scale
        self.floor_scale = config.floor_scale
        self.attn_temperature_tuning = config.attn_temperature_tuning
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.use_rope = int((layer_idx + 1) % 4 != 0)  # 1 on rotary layers, 0 on every fourth layer
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # The query/key norm only exists on rotary layers, and only when enabled in the config
        if self.config.use_qk_norm and self.use_rope:
            self.qk_norm = Llama4TextL2Norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            hidden_states (torch.Tensor): Input of shape (batch, seq_len, hidden_size).
            position_embeddings (torch.Tensor): Complex rotary factors passed to `apply_rotary_emb`.
            attention_mask (Optional[torch.Tensor]): Mask forwarded to the attention implementation.
            past_key_value (Optional[Cache]): Key/value cache, updated in place for this layer.
            cache_position (Optional[torch.LongTensor]): Positions of the current tokens; required when
                temperature tuning is active on this layer.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)` as produced by the attention implementation, with `attn_output`
            projected back to (batch, seq_len, hidden_size).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # (batch, seq_len, heads, head_dim); queries and keys are transposed later, after RoPE
        query_states = self.q_proj(hidden_states).view(hidden_shape)
        key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        if self.use_rope:  # the 16E model skips rope for long context on certain layers
            query_states, key_states = apply_rotary_emb(
                query_states, key_states, position_embeddings.to(query_states.device)
            )

        if hasattr(self, "qk_norm"):  # the 128E model does not use qk_norm
            query_states = self.qk_norm(query_states)
            key_states = self.qk_norm(key_states)

        # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers
        if self.attn_temperature_tuning and not self.use_rope:
            attn_scales = (
                torch.log(torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0
            )
            # `attn_scales` holds one value per position, so it is shaped (1, seq_len, 1, 1) and broadcast over
            # the batch. Viewing it as (batch, seq_len, 1, 1) fails as soon as the batch size exceeds 1.
            attn_scales = attn_scales.view((1, input_shape[-1], 1, 1))
            query_states = (query_states * attn_scales).to(query_states.dtype)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        if past_key_value is not None:
            # cache_position needed for the static cache
            cache_kwargs = {"cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
+
+
class Llama4TextDecoderLayer(nn.Module):
    """
    One pre-norm decoder block: self-attention followed by a feed-forward module (MoE or dense MLP),
    each wrapped in a residual connection.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Llama4TextAttention(config, layer_idx)
        # Every layer except each 4th one (1-based) uses the chunked mask <=> use rope
        self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)
        self.is_moe_layer = layer_idx in config.moe_layers
        # the 128E model interleaves dense / sparse
        self.feed_forward = (
            Llama4TextMoe(config)
            if self.is_moe_layer
            else Llama4TextMLP(config, intermediate_size=config.intermediate_size_mlp)
        )

        self.input_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        chunk_causal_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Apply the block to `hidden_states`.

        Returns a tuple whose first element is the new hidden states, followed by the attention weights
        when `output_attentions` is truthy and the router logits (`None` for dense layers) when
        `output_router_logits` is truthy.
        """
        # Chunked layers swap in the local mask whenever one was supplied.
        layer_mask = attention_mask
        if self.use_chunked_attention and chunk_causal_mask is not None:
            layer_mask = chunk_causal_mask

        # Self Attention
        attention_states, self_attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=layer_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        post_attention = hidden_states + attention_states

        # Fully Connected
        feed_forward_out = self.feed_forward(self.post_attention_layernorm(post_attention))
        router_logits = None
        if self.is_moe_layer:
            # MoE layers return the router logits alongside the activations.
            feed_forward_out, router_logits = feed_forward_out
        block_output = post_attention + feed_forward_out.view(post_attention.shape)

        outputs = (block_output,)
        if output_attentions:
            outputs = outputs + (self_attn_weights,)
        if output_router_logits:
            outputs = outputs + (router_logits,)
        return outputs
+
+
+LLAMA4_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Llama4Config`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
# Common base for the Llama4 models in this file: declares the supported attention/cache backends and the
# weight-initialisation scheme.
@add_start_docstrings(
    "The bare Llama4 Model outputting raw hidden-states without any specific head on top.",
    LLAMA4_START_DOCSTRING,
)
class Llama4PreTrainedModel(PreTrainedModel):
    config_class = Llama4Config
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        # Take the std from the config itself when it defines one, otherwise from its text sub-config.
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            std = self.config.text_config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # The padding row is reset to zeros after the random init.
                module.weight.data[module.padding_idx].zero_()
+
+
+LLAMA4_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
# Text-only decoder stack: token embedding, `config.num_hidden_layers` decoder layers, final RMS norm and
# the shared rotary embedding module.
@add_start_docstrings(
    "The bare Llama4 Model outputting raw hidden-states without any specific head on top.",
    LLAMA4_START_DOCSTRING,
)
class Llama4TextModel(Llama4PreTrainedModel):
    _no_split_modules = ["Llama4TextDecoderLayer"]
    base_model_prefix = "model"
    config_class = Llama4TextConfig

    def __init__(self, config: Llama4TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Llama4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Unset flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Raises when both or neither of `input_ids` / `inputs_embeds` are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            # Continue numbering from the tokens already held in the cache.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask, chunk_causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        freq_cis = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Arguments are positional here, so they must line up with
                # `Llama4TextDecoderLayer.forward`; the `output_router_logits` slot has to be filled
                # explicitly, otherwise every later argument shifts by one parameter.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    chunk_causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    False,  # output_router_logits
                    use_cache,
                    cache_position,
                    freq_cis,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    chunk_causal_mask=chunk_causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=freq_cis,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        output = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
        return output if return_dict else output.to_tuple()

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
        chunked_attention_mask=None,
    ):
        """
        Build the pair `(causal_mask, chunked_attention_mask)` consumed by the decoder layers.

        The first element is the full causal mask, the second the chunk-local mask used by layers with
        chunked attention. Either element may be `None` (for instance when the attention implementation is
        not one of sdpa / flex_attention / eager, or when no chunked mask is needed).
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask, attention_mask  # flash does not support chunked attn TODO support flash
            return None, None

        if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:
            return None, None

        sequence_length = input_tensor.shape[1]
        cache_position = cache_position.to(self.device)
        attention_chunk_size = self.config.attention_chunk_size

        first_cache_position = cache_position[0]
        last_cache_position = cache_position[-1]

        # to avoid graph break, we introduce this hack
        cond1 = first_cache_position >= attention_chunk_size
        cond2 = (first_cache_position < attention_chunk_size) & (
            first_cache_position + sequence_length > attention_chunk_size
        )

        key_length = torch.where(
            cond1,
            attention_chunk_size + sequence_length - 1,
            torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
        )

        if past_key_values is not None and past_key_values.is_compileable:
            # `get_max_cache_shape` is a method; it must be called to obtain the length.
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length

        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                offsets = (first_cache_position, max(last_cache_position - key_length, 0))
                chunked_attention_mask = make_flex_block_causal_mask(
                    attention_mask, self.config.attention_chunk_size, sequence_length, key_length, offsets=offsets
                )
                # NOTE(review): this dereferences `past_key_values` without a None check -- confirm callers
                # always provide a cache on the flex path.
                attention_mask = make_flex_block_causal_mask(
                    attention_mask,
                    query_length=sequence_length,
                    key_length=past_key_values.get_max_cache_shape(),
                    offsets=None if sequence_length != 1 else (first_cache_position, 0),
                )
                return attention_mask, chunked_attention_mask
            if isinstance(attention_mask, BlockMask):
                return attention_mask, chunked_attention_mask

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        dtype, device = input_tensor.dtype, input_tensor.device
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        if target_length > self.config.attention_chunk_size:
            chunked_attention_mask = self.create_chunked_attention_mask(
                self.config.attention_chunk_size,
                start=first_cache_position,
                end=first_cache_position + key_length,
                device=device,
            )
            # NOTE(review): `attention_mask` is used here without a None check -- confirm it is always
            # provided when the target length exceeds the chunk size.
            chunked_attention_mask = chunked_attention_mask & attention_mask
            if sequence_length == 1:
                chunked_attention_mask = chunked_attention_mask[-1:]
            if self.config._attn_implementation == "eager":
                chunked_attention_mask = (
                    chunked_attention_mask[None, None, :, :]
                    .to(dtype)
                    .masked_fill(chunked_attention_mask, torch.finfo(dtype).min)
                )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and attention_mask.ndim == 4
            and not output_attentions  # Only unmask for 4d masks
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
            # chunked_attention_mask = AttentionMaskConverter._unmask_unattended(chunked_attention_mask, min_dtype)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
            chunked_attention_mask = chunked_attention_mask.bool()
            causal_mask = causal_mask.bool()
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=first_cache_position,
                is_training=self.training,
            ):
                causal_mask = None
        return causal_mask, chunked_attention_mask

    def create_chunked_attention_mask(
        self, attention_chunk_size: int, start: int, end: int, device: torch.device
    ) -> torch.Tensor:
        """
        Generate the following:

        'What'      :  0 ■ ⬚ ⬚ ⬚ ⬚ ⬚    |
        '▁is'       :  1 ■ ■ ⬚ ⬚ ⬚ ⬚     |
        '▁ch'       :  2 ■ ■ ■ ⬚ ⬚ ⬚     |
        'unked'     :  3 ⬚ ⬚ ⬚ ■ ⬚ ⬚    |
        '▁attention':  4 ⬚ ⬚ ⬚ ■ ■ ⬚    |
        '?'         :  5 ⬚ ⬚ ⬚ ■ ■ ■     |

        If the chunk size is 3.
        This can just be applied over the already created attention mask.

        Returns a boolean matrix over the positions `start..end-1` that is `True` where the query (row) and
        the key (column) fall in the same chunk and the key is not after the query.
        """
        block_pos = torch.abs(
            (torch.arange(start, end).unsqueeze(0) // attention_chunk_size)
            - (torch.arange(start, end).unsqueeze(1) // attention_chunk_size)
        )
        token_pos = torch.arange(start, end).unsqueeze(0) - torch.arange(start, end).unsqueeze(1)
        mask = (block_pos == 0) & (token_pos <= 0)
        return mask.to(device)

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Zero out (i.e. allow) every key position that is not after the query's cache position.
            causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # A sum of 0 means "causally visible (0) and padding (0)": those entries get masked out.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device)
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
+
+
class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
    # Text decoder (`self.model`) topped with a bias-free linear language-modeling head.
    base_model_prefix = "language_model"
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    config_class = Llama4TextConfig

    def __init__(self, config: Llama4TextConfig):
        super().__init__(config)
        self.model = Llama4TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            logits_to_keep (`int` or `torch.Tensor`, *optional*):
                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
                This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Llama4ForCausalLM

        >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Unset flags fall back to the values stored on the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs[0]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        if isinstance(logits_to_keep, int):
            # `slice(-0, None)` is `slice(0, None)`, so 0 keeps every position.
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple output: the loss is prepended only when it was computed.
        output = (logits,) + outputs[1:]
        return output if loss is None else (loss,) + output
+
+
@dataclass
class Llama4CausalLMOutputWithPast(ModelOutput):
    """
    Base class for Llama4 causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
            image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None
+
+
class Llama4VisionMLP2(torch.nn.Module):
    """
    Two bias-free linear layers, each followed by GELU, with dropout applied between them.

    `fc1` maps `config.intermediate_size` -> `config.projector_input_dim`; `fc2` maps
    `config.projector_output_dim` -> `config.projector_output_dim`.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.fc1 = nn.Linear(self.intermediate_size, config.projector_input_dim, bias=False)
        self.fc2 = nn.Linear(config.projector_output_dim, config.projector_output_dim, bias=False)
        # Fixed GELU rather than ACT2FN[config.hidden_act]
        self.activation_fn = nn.GELU()
        self.dropout = config.projector_dropout

    def forward(self, hidden_states):
        projected = self.activation_fn(self.fc1(hidden_states))
        # Dropout is only active in training mode.
        projected = F.dropout(projected, p=self.dropout, training=self.training)
        projected = self.fc2(projected)
        return self.activation_fn(projected)
+
+
class Llama4MultiModalProjector(nn.Module):
    """Single bias-free linear map from the vision output width to the text hidden size."""

    def __init__(self, config):
        super().__init__()
        vision_dim = config.vision_config.vision_output_dim
        text_dim = config.text_config.hidden_size
        self.linear_1 = nn.Linear(vision_dim, text_dim, bias=False)

    def forward(self, image_features):
        return self.linear_1(image_features)
+
+
def pixel_shuffle(input_tensor, shuffle_ratio):
    """
    Trade spatial resolution for channel depth on a square grid of patches.

    `input_tensor` is `[batch_size, num_patches, channels]`; the patches are laid out as a
    `sqrt(num_patches) x sqrt(num_patches)` grid. Height and width are each scaled by `shuffle_ratio` and
    the channel count by `1 / shuffle_ratio**2`, and the result is flattened back to
    `[batch_size, new_num_patches, new_channels]`.
    """
    batch_size, num_patches, _ = input_tensor.shape
    side = int(math.sqrt(num_patches))

    grid = input_tensor.view(batch_size, side, side, -1)
    batch_size, height, width, channels = grid.size()

    new_height = int(height * shuffle_ratio)
    new_width = int(width * shuffle_ratio)

    # First pass: rescale the width axis into channels, then swap the two spatial axes.
    folded = grid.view(batch_size, height, new_width, int(channels / shuffle_ratio))
    folded = folded.permute(0, 2, 1, 3).contiguous()

    # Second pass: rescale the (now trailing) height axis the same way and swap back.
    folded = folded.view(batch_size, new_height, new_width, int(channels / (shuffle_ratio**2)))
    folded = folded.permute(0, 2, 1, 3).contiguous()

    return folded.view(batch_size, -1, folded.shape[-1])
+
+
class Llama4VisionPixelShuffleMLP(nn.Module):
    """Applies `pixel_shuffle` with the configured ratio, then feeds the result through `Llama4VisionMLP2`."""

    def __init__(self, config):
        super().__init__()
        ratio = config.pixel_shuffle_ratio
        self.pixel_shuffle_ratio = ratio
        self.inner_dim = int(config.projector_input_dim // (ratio**2))
        self.output_dim = config.projector_output_dim
        self.mlp = Llama4VisionMLP2(config)

    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
        shuffled = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(shuffled)
+
+
# Class-level docstring preamble. The constant keeps its historical name so existing references keep
# working, but the text now points at the Llama4 configuration classes instead of the Llava ones.
LLAVA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Llama4Config`] or [`Llama4VisionConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
+
+# TODO there is a different RoPE for vision encoder, defined as below
+def reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):
+    """View `freqs_ci` so that it broadcasts against `query`.
+
+    Dimension 1 and the last dimension of `query` keep their sizes; every other
+    dimension of the target shape is set to 1.
+    """
+    last_axis = query.ndim - 1
+    target_shape = [size if axis in (1, last_axis) else 1 for axis, size in enumerate(query.shape)]
+    return freqs_ci.view(*target_shape)
+
+
+def vision_apply_rotary_emb(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    freqs_ci: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Rotate `query` and `key` by the complex frequencies in `freqs_ci`.
+
+    Adjacent pairs along the last dimension are reinterpreted as (real, imaginary)
+    parts in float32, multiplied by `freqs_ci`, flattened back to real values and
+    cast to the dtype of the corresponding input.
+    """
+
+    def _as_complex(states: torch.Tensor) -> torch.Tensor:
+        # Group the last dimension into pairs and view each pair as one complex number.
+        return torch.view_as_complex(states.float().reshape(*states.shape[:-1], -1, 2))
+
+    complex_query = _as_complex(query)
+    complex_key = _as_complex(key)
+    rotation = reshape_for_broadcast(freqs_ci=freqs_ci, query=complex_query)  # freqs_ci[:,:,None,:]
+    rotation = rotation.to(complex_query.device)
+    rotated_query = torch.view_as_real(complex_query * rotation).flatten(3)
+    rotated_key = torch.view_as_real(complex_key * rotation).flatten(3)
+    return rotated_query.type_as(query), rotated_key.type_as(key)  # but this drops to 8e-3
+
+
+class Llama4VisionAttention(nn.Module):
+    """Non-causal multi-head self-attention used by the vision encoder.
+
+    Queries and keys are rotated with the complex frequencies `freqs_ci` before the
+    attention backend is called. No attention mask is forwarded to the backend.
+    """
+
+    def __init__(self, config: Llama4VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+        # q, k and v are all projected to `num_heads` heads, so there is no key/value sharing.
+        self.num_key_value_groups = 1
+        self.attention_dropout = config.attention_dropout
+
+        self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True)
+        self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True)
+        self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=True)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=True)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        freqs_ci: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Run self-attention over `hidden_states`.
+
+        Args:
+            hidden_states: input whose last dimension is `hidden_size`
+                (presumably `(batch_size, seq_len, hidden_size)` -- the code transposes dims 1 and 2
+                after splitting heads).
+            freqs_ci: complex rotary frequencies applied to queries and keys.
+            attention_mask: accepted for interface compatibility; not used, `None` is passed to the backend.
+            past_key_value: accepted for interface compatibility; not used.
+            **kwargs: forwarded to the attention backend. `output_attentions` is also read here to
+                decide whether SDPA must fall back to eager attention.
+
+        Returns:
+            `(attn_output, attn_weights)` where `attn_weights` is whatever the selected backend returns.
+        """
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape)
+        key_states = self.k_proj(hidden_states).view(hidden_shape)
+        value_states = self.v_proj(hidden_states).view(hidden_shape)
+
+        # Rotary embedding is applied before moving the head dimension in front of the sequence.
+        query_states, key_states = vision_apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci)
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attention_interface: Callable = eager_attention_forward
+        # flex disable because breaks on TP 8, embed is 88 not power of 2
+        if self.config._attn_implementation not in ["eager", "flex_attention"]:
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
+            else:
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            None,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=None,
+            is_causal=False,  # HAS TO BE ENFORCED
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+
+class Llama4VisionMLP(nn.Module):
+    """Two-layer feed-forward network: `fc1`, GELU, `fc2`."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.activation_fn = nn.GELU()  # ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(in_features=config.hidden_size, out_features=config.intermediate_size, bias=True)
+        self.fc2 = nn.Linear(in_features=config.intermediate_size, out_features=config.hidden_size, bias=True)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Expand to `intermediate_size`, apply the activation, and project back to `hidden_size`."""
+        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
+
+
+class Llama4VisionEncoderLayer(nn.Module):
+    """Pre-norm transformer block: self-attention and an MLP, each wrapped in a residual connection."""
+
+    def __init__(self, config: Llama4VisionConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = Llama4VisionAttention(config)
+        self.mlp = Llama4VisionMLP(config)
+
+        self.input_layernorm = nn.LayerNorm(config.hidden_size)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def forward(
+        self,
+        hidden_state: torch.Tensor,
+        freqs_ci: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+    ) -> Tuple[torch.Tensor, ...]:
+        """Apply the attention and feed-forward sub-blocks.
+
+        Args:
+            hidden_state: input activations of the layer.
+            freqs_ci: rotary frequencies handed to the attention module.
+            attention_mask: forwarded to the attention module.
+            output_attentions: when truthy, the attention weights returned by the attention
+                module are appended to the output tuple.
+
+        Returns:
+            `(hidden_state,)`, or `(hidden_state, attn_weights)` when `output_attentions` is truthy.
+        """
+        # Self Attention
+        residual = hidden_state
+
+        hidden_state = self.input_layernorm(hidden_state)
+
+        hidden_state, attn_weights = self.self_attn(
+            hidden_state,
+            freqs_ci=freqs_ci,
+            attention_mask=attention_mask,
+        )
+        hidden_state = residual + hidden_state
+
+        # Feed forward
+        residual = hidden_state
+        hidden_state = self.post_attention_layernorm(hidden_state)
+        hidden_state = self.mlp(hidden_state)
+        hidden_state = residual + hidden_state
+
+        outputs = (hidden_state,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class Llama4VisionEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`Llama4VisionEncoderLayer`].
+
+    Args:
+        config: Llama4VisionConfig
+    """
+
+    def __init__(self, config: Llama4VisionConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([Llama4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        freqs_ci: torch.Tensor,  # TODO move this to an attribute instead of keeping it around
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                Embedded inputs that are fed through the encoder layers.
+            freqs_ci (`torch.Tensor`):
+                Rotary frequencies passed unchanged to every encoder layer.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        for encoder_layer in self.layers:
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            if self.gradient_checkpointing and self.training:
+                # Arguments are forwarded positionally, so they must follow the layer's `forward`
+                # signature: (hidden_state, freqs_ci, attention_mask, output_attentions).
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    freqs_ci,
+                    attention_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_state=hidden_states,
+                    attention_mask=attention_mask,
+                    output_attentions=output_attentions,
+                    freqs_ci=freqs_ci,
+                )
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+            hidden_states = layer_outputs[0]
+
+        # The output of the last layer is recorded after the loop.
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+
+class Llama4UnfoldConvolution(nn.Module):
+    """Patch embedding: unfold the image into non-overlapping patches, then project each one linearly."""
+
+    def __init__(self, config):
+        super().__init__()
+        patch = config.patch_size
+        kernel_size = (patch, patch) if isinstance(patch, int) else patch
+        self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size)
+        flattened_patch_dim = config.num_channels * kernel_size[0] * kernel_size[1]
+        self.linear = nn.Linear(flattened_patch_dim, config.hidden_size, bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Return one `hidden_size` embedding per patch, with the patch axis at dim 1."""
+        # `Unfold` yields (batch, flattened_patch_dim, num_patches); move patches to dim 1.
+        patches = self.unfold(hidden_states).permute(0, 2, 1)
+        return self.linear(patches)
+
+
+class Llama4VisionRotaryEmbedding(nn.Module):
+    """Precomputes the complex 2D rotary frequencies consumed by the vision attention layers.
+
+    The table holds one row per image patch plus one trailing row (tagged `ID_CLS_TOKEN`)
+    whose angles are zeroed, so that row corresponds to the identity rotation `1 + 0j`.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        # Number of patches along one side of the image.
+        idx = config.image_size // config.patch_size
+        # Flat patch indices 0 .. idx**2 - 1 as a column vector.
+        img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1)
+        # Append one extra row and give it a negative id so it can be masked out below.
+        img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
+        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
+        frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x
+        frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y
+        # Half of the per-head dimension: each complex entry later covers two real channels.
+        freq_dim = config.hidden_size // config.num_attention_heads // 2
+        rope_freq = 1.0 / (config.rope_theta ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim))
+        # Coordinates are shifted by one before being scaled by the inverse frequencies.
+        freqs_x = ((frequencies_x + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1)
+        freqs_y = ((frequencies_y + 1)[..., None] * rope_freq[None, None, :]).repeat_interleave(2, dim=-1)
+        # Concatenate x and y angles, then keep every second entry to undo the duplication above.
+        freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
+        # Rows with a negative id (the extra trailing row) get angle 0.
+        freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0)
+        freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
+        # Stored as a plain attribute (not registered as a buffer or parameter).
+        # Shape: (idx**2 + 1, 1, 2 * (freq_dim // 2)), i.e. last dim is freq_dim when freq_dim is even.
+        self.freqs_ci = freq_cis
+
+    def forward(self, hidden_states):
+        """Return the precomputed table moved to the device of `hidden_states`."""
+        return self.freqs_ci.to(hidden_states.device)
+
+
+class Llama4VisionModel(Llama4PreTrainedModel):
+    """Vision tower: patch embedding, transformer encoder, and pixel-shuffle adapter."""
+
+    base_model_prefix = "vision_model"
+    _no_split_modules = ["Llama4VisionAttention"]
+    config_class = Llama4VisionConfig
+
+    def __init__(self, config: Llama4VisionConfig):
+        super().__init__(config)
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        self.hidden_size = config.hidden_size
+        self.num_channels = config.num_channels
+
+        # One position per patch plus one for the class token.
+        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+        self.scale = config.hidden_size**-0.5
+
+        self.patch_embedding = Llama4UnfoldConvolution(config)
+
+        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
+        self.positional_embedding_vlm = nn.Parameter(self.scale * torch.randn(self.num_patches, self.hidden_size))
+        self.rotary_embedding = Llama4VisionRotaryEmbedding(config)
+
+        # layer norms
+        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
+        self.layernorm_post = nn.LayerNorm(self.hidden_size)
+
+        # encoders
+        self.model = Llama4VisionEncoder(config)
+        self.vision_adapter = Llama4VisionPixelShuffleMLP(config)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        """
+        This function is used to fetch the first embedding layer to activate grads on inputs.
+        """
+        return self.patch_embedding
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
+        r"""
+        Args:
+            pixel_values (`torch.Tensor`):
+                4D image tensor; its first dimension is treated as `batch_size * num_tiles`.
+            attention_mask (`torch.Tensor`, *optional*):
+                Accepted for interface compatibility; the encoder is always called with `attention_mask=None`.
+            output_attentions (`bool`, *optional*):
+                Whether to return the encoder's attention weights.
+            output_hidden_states (`bool`, *optional*):
+                Whether to return the encoder's per-layer hidden states.
+            return_dict (`bool`, *optional*):
+                Whether to return a [`BaseModelOutput`] instead of a plain tuple.
+
+        Returns:
+            `last_hidden_state` is the encoder output after the final layer norm, with the trailing class
+            token dropped and the vision adapter applied. `hidden_states` and `attentions` come straight
+            from the encoder.
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, MllamaVisionModel
+
+        >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
+        >>> model = MllamaVisionModel.from_pretrained(checkpoint)
+        >>> processor = AutoProcessor.from_pretrained(checkpoint)
+
+        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> output = model(**inputs)
+
+        >>> print(output.last_hidden_state.shape)
+        torch.Size([1, 1, 4, 1025, 7680])
+        ```
+        """
+        # NOTE(review): the docstring example above still references `MllamaVisionModel` and a Llama 3.2
+        # checkpoint; it should be replaced with a Llama4 example once a checkpoint name is confirmed.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # num_concurrent_media and num_chunks are both currently 1
+        # The unpacking also enforces that `pixel_values` is 4D.
+        batch_size_times_num_tiles, _, _, _ = pixel_values.shape
+        num_concurrent_media = 1
+        num_chunks = 1
+        hidden_state = self.patch_embedding(pixel_values)
+        _, num_patches, hidden_dim = hidden_state.shape
+
+        # Add cls token (appended after the patch tokens)
+        hidden_state = hidden_state.reshape(
+            batch_size_times_num_tiles * num_concurrent_media * num_chunks, num_patches, hidden_dim
+        )
+        class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1])
+        hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
+        num_patches += 1
+
+        # Position embeddings
+        hidden_state = hidden_state.reshape(
+            batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim
+        )
+        positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device)
+        hidden_state = hidden_state + positional_embedding
+
+        hidden_state = self.layernorm_pre(hidden_state)
+
+        hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)
+        freqs_ci = self.rotary_embedding(pixel_values)
+
+        # Always request a ModelOutput so the attribute access below works even when the
+        # config disables `use_return_dict`.
+        output = self.model(
+            hidden_state,
+            attention_mask=None,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            freqs_ci=freqs_ci,
+            return_dict=True,
+        )
+
+        hidden_state = output.last_hidden_state
+
+        hidden_state = self.layernorm_post(hidden_state)
+
+        # Drop the class token, which was appended last.
+        hidden_state = hidden_state[:, :-1, :]
+
+        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
+        hidden_state = self.vision_adapter(hidden_state)
+
+        hidden_states = output.hidden_states if output_hidden_states else None
+
+        # Access by name: integer indexing on a ModelOutput skips `None` fields, so `output[2]`
+        # is out of range whenever hidden states were not requested.
+        attentions = output.attentions if output_attentions else None
+
+        if not return_dict:
+            return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None)
+
+        return BaseModelOutput(
+            last_hidden_state=hidden_state,
+            hidden_states=hidden_states,
+            attentions=attentions,
+        )
+
+
+class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin):
+    """Combines the Llama4 vision tower, a multimodal projector, and the causal language model."""
+
+    _tp_plan = {}
+    base_model_prefix = ""
+    config_class = Llama4Config
+    _supports_flex_attn = True
+
+    def __init__(self, config: Llama4Config):
+        super().__init__(config)
+        self.vision_model = Llama4VisionModel(config.vision_config)
+
+        self.multi_modal_projector = Llama4MultiModalProjector(config)
+        self.language_model = Llama4ForCausalLM(config.text_config)
+        self.vocab_size = config.text_config.vocab_size
+        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def get_output_embeddings(self):
+        return self.language_model.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    def set_decoder(self, decoder):
+        self.language_model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        vision_feature_layer: Union[int, List[int]],
+        vision_feature_select_strategy: str,
+        **kwargs,
+    ):
+        """
+        Obtains the image last hidden states from the vision tower.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+               The tensors corresponding to the input images.
+            vision_feature_layer (`Union[int, List[int]]`):
+                The index of the layer to select the vision feature. If multiple indices are provided,
+                the vision feature of the corresponding indices will be concatenated to form the
+                vision features. Not used by the current implementation.
+            vision_feature_select_strategy (`str`):
+                The feature selection strategy used to select the vision feature from the vision backbone.
+                Can be one of `"default"` or `"full"`
+            kwargs:
+                Extra keyword arguments; entries whose value is `None` are dropped and the rest are
+                forwarded to the vision model.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+
+        Raises:
+            ValueError: if `vision_feature_select_strategy` is not `"default"` or `"full"`.
+        """
+        if vision_feature_select_strategy not in ["default", "full"]:
+            # Report the argument that was actually validated; the model has no attribute of that name.
+            raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        image_outputs = self.vision_model(pixel_values, output_hidden_states=False, **kwargs)
+        hidden_state = image_outputs.last_hidden_state
+        return hidden_state
+
+    @replace_return_docstrings(output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        vision_feature_layer: Optional[Union[int, List[int]]] = None,
+        vision_feature_select_strategy: Optional[str] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        image_sizes: torch.Tensor = None,
+        **lm_kwargs,
+    ) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
+        r"""
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+            logits_to_keep (`int` or `torch.Tensor`, *optional*):
+                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+                This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
+
+        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
+        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+
+        >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:"
+        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "USER:  \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
+        ```"""
+        # NOTE(review): the docstring example above still uses the Llava classes and checkpoint; it should be
+        # replaced with a Llama4 example once a checkpoint name is confirmed.
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        vision_feature_layer = (
+            vision_feature_layer
+            if vision_feature_layer is not None
+            else self.config.vision_config.vision_feature_layer
+        )
+        vision_feature_select_strategy = (
+            vision_feature_select_strategy
+            if vision_feature_select_strategy is not None
+            else self.config.vision_config.vision_feature_select_strategy
+        )
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if pixel_values is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if pixel_values is not None:
+            image_features = self.get_image_features(
+                pixel_values=pixel_values,
+                vision_feature_layer=vision_feature_layer,
+                vision_feature_select_strategy=vision_feature_select_strategy,
+                image_sizes=image_sizes,
+            )
+            original_inputs_embeds_shape = inputs_embeds.shape
+
+            # Flatten the vision features to (num_vision_tokens, dim) and project them.
+            vision_flat = image_features.view(-1, image_features.size(-1))
+            projected_vision_flat = self.multi_modal_projector(vision_flat)
+
+            # Positions of the image placeholder tokens in the text sequence.
+            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+            final_mask = special_image_mask.to(inputs_embeds.device)
+            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
+
+            final_mask_1d = final_mask[..., 0].reshape(-1)
+            num_tokens_to_fill = final_mask_1d.sum()
+
+            if num_tokens_to_fill != projected_vision_flat.size(0):
+                raise ValueError(
+                    f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
+                    f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
+                )
+
+            # Overwrite the placeholder embeddings in place with the projected vision features.
+            expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
+            inputs_embeds.masked_scatter_(expanded_mask, projected_vision_flat)
+
+            inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
+
+        outputs = self.language_model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
+            **lm_kwargs,
+        )
+
+        logits = outputs[0]
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
+                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+            else:
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(
+                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return Llama4CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=image_features if pixel_values is not None else None,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        inputs_embeds=None,
+        pixel_values=None,
+        attention_mask=None,
+        cache_position=None,
+        logits_to_keep=None,
+        **kwargs,
+    ):
+        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+        model_inputs = self.language_model.prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
+
+        if cache_position[0] == 0:
+            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+            # Otherwise we need pixel values to be passed to model
+            model_inputs["pixel_values"] = pixel_values
+
+        return model_inputs
+
+    @staticmethod
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                    causal_mask.device
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+
+
+__all__ = [
+ "Llama4PreTrainedModel",
+ "Llama4TextModel",
+ "Llama4VisionModel",
+ "Llama4ForCausalLM",
+ "Llama4ForConditionalGeneration",
+]
diff --git a/src/transformers/models/llama4/processing_llama4.py b/src/transformers/models/llama4/processing_llama4.py
new file mode 100644
index 0000000000..0ca4a44c5e
--- /dev/null
+++ b/src/transformers/models/llama4/processing_llama4.py
@@ -0,0 +1,275 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Union
+
+from transformers.processing_utils import (
+ ImagesKwargs,
+ ProcessingKwargs,
+ ProcessorMixin,
+ Unpack,
+)
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import (
+ ImageInput,
+ make_flat_list_of_images,
+)
+
+
+class Llama4ImagesKwargs(ImagesKwargs, total=False):
+ # Image keyword arguments for Llama4: extends `ImagesKwargs` with two extra keys.
+ # Declared with `total=False`, so neither key has to be supplied.
+ # NOTE(review): the meaning of both keys is defined by the image processor, which is not visible here.
+ max_patches: Optional[int]
+ resize_to_max_canvas: Optional[bool]
+
+
+class Llama4ProcessorKwargs(ProcessingKwargs, total=False):
+ # Keyword-argument schema that `Llama4Processor.__call__` passes to `_merge_kwargs`.
+ images_kwargs: Llama4ImagesKwargs
+ # Default value for the text kwargs: padding side is "left".
+ # NOTE(review): how `_defaults` is merged with caller kwargs is decided by `_merge_kwargs` — confirm there.
+ _defaults = {
+ "text_kwargs": {
+ "padding_side": "left",
+ },
+ }
+
+
+chat_template = "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %} \n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- else %}\n {#- FIXME: The processor requires an array, always. #}\n {%- set system_message = messages[0]['content'][0]['text']|trim %}\n {%- endif %}\n {%- set messages = messages[1:] %}\n {%- set user_supplied_system_message = true %}\n{%- else %}\n {%- set system_message = \"\" %}\n {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- System message if the user supplied one #}\n{%- if user_supplied_system_message %}\n {{- \"<|header_start|>system<|header_end|>\n\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\n\" }}\n {%- endif %}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|header_start|>user<|header_end|>\n\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\n\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' -}}\n {{- '<|python_start|>' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|python_end|>' }}\n {%- for tool_call in message.tool_calls %}\n {{- '{\"name\": \"' + tool_call.function.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.function.arguments | tojson }}\n {{- \"}\" }}\n {%- endfor %}\n {{- \"<|eot|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|header_start|>ipython<|header_end|>\n\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' }}\n{%- endif %}\n"
+
+
+class Llama4Processor(ProcessorMixin):
+ r"""
+ Constructs a Llama4 processor which wraps a [`AutoImageProcessor`] and
+ [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
+ tokenizer functionalities. See the [`~Llama4Processor.__call__`] and [`~Llama4Processor.decode`] for more information.
+ Args:
+ image_processor ([`AutoImageProcessor`], *optional*):
+ The image processor is a required input.
+ tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
+ The tokenizer is a required input.
+ patch_size (`int`, *optional*, defaults to 28):
+ The size of image patches for tokenization.
+ img_size (`int`, *optional*, defaults to 364):
+ The size of the image to be tokenized. This should correspond to the size given to the image processor.
+ image_token (`str`, *optional*, defaults to `"<|image|>"`):
+ The token to be used to represent an image in the text.
+ downsample_factor (`int`, *optional*, defaults to 1):
+ The factor by which to scale the patch size.
+ start_of_img_token (`str`, *optional*, defaults to `"<|START_OF_IMG|>"`):
+ The token to be used to represent the start of an image in the text.
+ end_of_img_token (`str`, *optional*, defaults to `"<|END_OF_IMG|>"`):
+ The token to be used to represent the end of an image in the text.
+ img_patch_token (`str`, *optional*, defaults to `"<|IMG_PATCH|>"`):
+ The token to be used to represent an image patch in the text.
+ img_line_break_token (`str`, *optional*, defaults to `"<|IMG_LINE_BREAK|>"`):
+ The token to be used to represent a line break in the text.
+ tile_token (`str`, *optional*, defaults to `"TILE"`):
+ The token to be used to represent an image patch in the text.
+ tile_global_token (`str`, *optional*, defaults to `"TILE_GLOBAL"`):
+ The token to be used to represent the cover image in the text.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = [
+ "chat_template",
+ "image_token",
+ "patch_size",
+ "img_size",
+ "downsample_factor",
+ "start_of_img_token",
+ "end_of_img_token",
+ "img_patch_token",
+ "img_line_break_token",
+ "tile_token",
+ "tile_global_token",
+ ]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(
+ self,
+ image_processor=None,
+ tokenizer=None,
+ patch_size: int = 14,
+ pixel_shuffle_ratio: float = 0.5,
+ fake_image_token="<|image|>",
+ image_token="<|image|>",
+ start_of_image_token="<|image_start|>",
+ end_of_image_token="<|image_end|>",
+ patch_token="<|patch|>",
+ tile_x_separator_token="<|tile_x_separator|>",
+ tile_y_separator_token="<|tile_y_separator|>",
+ chat_template=chat_template,
+ **kwargs,
+ ):
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ self.downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
+ self.patch_size = patch_size
+
+ self.fake_image_token = fake_image_token
+ self.image_token = image_token
+ self.start_of_img_token = start_of_image_token
+ self.end_of_img_token = end_of_image_token
+ self.img_patch_token = patch_token
+ self.tile_token = tile_x_separator_token
+ self.tile_global_token = tile_y_separator_token
+
+ def _prompt_split_image(self, aspect_ratio, num_patches_per_chunk):
+ """
+ Create a structured string representation of image tokens
+
+ Args:
+ num_patches: Number of patches in the image
+
+ Returns:
+ String with appropriate image tokens
+ """
+ img_string = "<|image_start|>"
+ ratio_h, ratio_w = aspect_ratio
+ if ratio_h * ratio_w > 1:
+ for yy in range(ratio_h):
+ for xx in range(ratio_w):
+ img_string += "<|patch|>" * num_patches_per_chunk
+ if xx < ratio_w - 1:
+ img_string += "<|tile_x_separator|>"
+
+ img_string += "<|tile_y_separator|>"
+ img_string += "<|image|>"
+ img_string += "<|patch|>" * num_patches_per_chunk
+ img_string += "<|image_end|>"
+
+ return img_string
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[Llama4ProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text.
+ To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
+ Llama4ImageProcessor's [`~Llama4ImageProcessor.__call__`] if `images` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ if text is None:
+ raise ValueError("You have to specify text.")
+
+ output_kwargs = self._merge_kwargs(
+ Llama4ProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if not isinstance(text, (list, tuple)):
+ text = [text]
+
+ # Process images
+ image_inputs = {}
+ if images is not None:
+ images = make_flat_list_of_images(images)
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+ image_height, image_width = image_inputs["pixel_values"][0].shape[-2:]
+ num_patches_per_chunk = int(
+ (image_height // self.patch_size) * (image_width // self.patch_size) // self.downsample_ratio
+ )
+ aspect_ratios = image_inputs.pop("aspect_ratios")
+
+ total_placeholders = sum(prompt.count(self.fake_image_token) for prompt in text)
+ if total_placeholders != len(images):
+ raise ValueError(
+ f"Found {total_placeholders} placeholders across the batch, "
+ f"but have {len(images)} flattened images."
+ )
+
+ image_index = 0
+ processed_text = []
+ for prompt in text:
+ placeholder_count = prompt.count(self.fake_image_token)
+ if placeholder_count == 0:
+ # do nothing if there is no image
+ processed_text.append(prompt)
+ continue
+ prompt_splits = prompt.split(self.fake_image_token)
+ new_prompt = []
+ for local_image_index, split_part in enumerate(prompt_splits):
+ new_prompt.append(split_part)
+ if local_image_index < placeholder_count:
+ tokens_for_this_image = self._prompt_split_image(
+ aspect_ratios[image_index], num_patches_per_chunk
+ )
+ image_index += 1
+ new_prompt.append(tokens_for_this_image)
+ processed_text.append("".join(new_prompt))
+
+ if image_index != len(images):
+ raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
+
+ text = processed_text
+
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ return BatchFeature(data={**text_inputs, **image_inputs})
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(tokenizer_input_names) + list(image_processor_input_names)
+
+
+__all__ = ["Llama4Processor"]
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 714d69baf4..1df663a26e 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -981,6 +981,8 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
else:
self.device = device if device is not None else -1
+ if torch.distributed.is_initialized():
+ self.device = self.model.device
logger.warning(f"Device set to use {self.device}")
self.binary_output = binary_output
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index b486526388..b1c40e7ff2 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -1178,10 +1178,6 @@ class ProcessorMixin(PushToHubMixin):
unused_kwargs = {}
unused_keys = set(kwargs_from_config) - set(valid_kwargs)
if unused_keys:
- unused_key_str = ", ".join(unused_keys)
- logger.warning(
- f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. "
- )
unused_kwargs = {k: processor_config[k] for k in unused_keys}
return unused_kwargs
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index cf8a120958..ff3856c9d4 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -43,8 +43,7 @@ is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_d
_torch_distributed_available = torch.distributed.is_available()
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
- from torch.distributed.tensor import Replicate
- from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+ pass
def softmax_backward_data(parent, grad_output, output, dim, self):
@@ -335,29 +334,6 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
return torch.isin(elements, test_elements)
-# TODO need to add the __repr__ that shows that it is a colwise parallel
-# See https://github.com/pytorch/pytorch/issues/145726
-def translate_to_torch_parallel_style(style: str):
- """
- In model configurations, we use a neutral type (string) to specify parallel
- styles, here we translate them into torch.distributed tensor-parallel
- types.
- """
- if not isinstance(style, str):
- raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
-
- if style == "colwise":
- return ColwiseParallel()
- elif style == "rowwise":
- return RowwiseParallel()
- elif style == "colwise_rep":
- return ColwiseParallel(output_layouts=Replicate())
- elif style == "rowwise_rep":
- return RowwiseParallel(input_layouts=Replicate())
- else:
- raise ValueError(f"Unsupported parallel style value: {style}")
-
-
def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
"""
LRU cache decorator from standard functools library, but with a workaround to disable
@@ -382,88 +358,3 @@ def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
return wrapper
return decorator
-
-
-def distribute_module(
- module: nn.Module,
- device_mesh=None,
- partition_fn=None,
- input_fn=None,
- output_fn=None,
-) -> nn.Module:
- """
- This function expose three functions to control the parameters/inputs/outputs of the module:
-
- 1. To perform sharding on the module before runtime execution by specifying the
- ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor`
- parameters according to the `partition_fn` specified).
- 2. To control the inputs or outputs of the module during runtime execution by
- specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to
- :class:`DTensor`, convert the output back to ``torch.Tensor``)
-
- Args:
- module (:class:`nn.Module`): user module to be partitioned.
- device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
- partition_fn (Callable): the function to partition parameters (i.e. shard certain
- parameters across the ``device_mesh``). If ``partition_fn`` is not specified,
- by default we replicate all module parameters of ``module`` across the mesh.
- input_fn (Callable): specify the input distribution, i.e. could control how the
- input of the module is sharded. ``input_fn`` will be installed as a module
- ``forward_pre_hook`` (pre forward hook).
- output_fn (Callable): specify the output distribution, i.e. could control how the
- output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be
- installed as a module ``forward_hook`` (post forward hook).
-
- Returns:
- A module that contains parameters/buffers that are all ``DTensor`` s.
-
- .. note::
- When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module``
- return nn.Module with PyTorch/XLA SPMD annotated parameters. See
- `this issue `__
- for more details. The XLA integration is experimental and subject to change.
-
- """
-
- torch._C._log_api_usage_once("torch.dtensor.distribute_module")
-
- device_mesh = device_mesh
-
- # register input_fn as module forward pre hook
- if input_fn is not None:
- # check the input_fn signature
- num_args = len(inspect.signature(input_fn).parameters)
- if num_args == 2:
- # input_fn only takes in inputs and device mesh
- logger.warning(
- "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
- "please use input_fn that takes in (module, inputs, device_mesh) instead!",
- FutureWarning,
- stacklevel=2,
- )
- module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg]
- elif num_args == 3:
- # input_fn takes in module, inputs, device mesh
- module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
- else:
- raise ValueError(f"input_fn should take in 3 arguments, but got {num_args} arguments!")
- # register output_fn as module forward hook
- if output_fn is not None:
- num_args = len(inspect.signature(output_fn).parameters)
- if num_args == 2:
- # output_fn only takes in outputs and device mesh
- logger.warning(
- "Deprecating output_fn that takes two arguments (inputs, device_mesh), "
- "please use output_fn that takes in (module, inputs, device_mesh) instead!",
- FutureWarning,
- stacklevel=2,
- )
- module.register_forward_hook(
- lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg]
- )
- elif num_args == 3:
- module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
- else:
- raise ValueError(f"output_fn should take in 3 arguments, but got {num_args} arguments!")
-
- return module
diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
old mode 100755
new mode 100644
index 46c44b79f2..a780dca754
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -15,7 +15,8 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..utils import is_torch_available
-from ..utils.quantization_config import QuantizationConfigMixin
+from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod
+from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
@@ -23,6 +24,9 @@ if TYPE_CHECKING:
if is_torch_available():
import torch
+ from torch.nn import ModuleList
+else:
+ ModuleList = str
class HfQuantizer(ABC):
@@ -198,6 +202,10 @@ class HfQuantizer(ABC):
"""
return
+ def update_tp_plan(self, config):
+ """
+ Hook for adjusting the tp plan stored on `config`, e.g. so that it also covers quantization scales.
+
+ The base implementation changes nothing and returns `config` as received.
+ """
+ return config
+
def preprocess_model(self, model: "PreTrainedModel", **kwargs):
"""
Setting model attributes and/or converting model before weights loading. At this point
@@ -212,6 +220,7 @@ class HfQuantizer(ABC):
"""
model.is_quantized = True
model.quantization_method = self.quantization_config.quant_method
+ self._convert_model_for_quantization(model)
return self._process_model_before_weight_loading(model, **kwargs)
def postprocess_model(self, model: "PreTrainedModel", **kwargs):
@@ -288,3 +297,44 @@ class HfQuantizer(ABC):
@property
@abstractmethod
def is_trainable(self): ...
+
+ def _convert_model_for_quantization(self, model):
+ from accelerate import init_empty_weights
+
+ for name, module in model.named_modules():
+ module_class_name = module.__class__.__name__
+ if (
+ module_class_name in MODULES_TO_PATCH_FOR_QUANTIZATION.keys()
+ and self.quantization_config.quant_method == QuantizationMethod.COMPRESSED_TENSORS
+ ):
+ with init_empty_weights():
+ parent_module, name = get_module_from_name(model, name)
+ parent_module._modules[name] = MODULES_TO_PATCH_FOR_QUANTIZATION[module_class_name](
+ model.config.get_text_config()
+ )
+
+
+class SequentialLlama4TextExperts(ModuleList):
+ """
+ A module that implements a compressed version of a list of expert modules.
+ This is specifically designed to work with Llama4TextExperts in MoE layers.
+ """
+
+ def __init__(self, config):
+ from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
+
+ super().__init__([Llama4TextMLP(config) for _ in range(config.num_local_experts)])
+ self.num_experts = config.num_local_experts
+
+ def forward(
+ self,
+ hidden_states: "torch.Tensor",
+ ) -> "torch.Tensor":
+ hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
+ routed_out = torch.zeros_like(hidden_states)
+ for expert_idx in range(self.num_experts):
+ routed_out[expert_idx] = self[expert_idx](hidden_states[expert_idx])
+ return routed_out
+
+
+# Maps a module's class name to the class that `HfQuantizer._convert_model_for_quantization` instantiates
+# in its place (that method only patches when the quantization method is compressed-tensors).
+MODULES_TO_PATCH_FOR_QUANTIZATION = {"Llama4TextExperts": SequentialLlama4TextExperts}
diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py
index ee1d0df380..4e45abf953 100644
--- a/src/transformers/quantizers/quantizer_compressed_tensors.py
+++ b/src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -146,6 +146,19 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
self.compressor.decompress(model_path=cache_path, model=model)
+ def update_tp_plan(self, config):
+ additional_plan = {
+ "layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise",
+ "layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise",
+ "layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise",
+ "layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise",
+ "layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise",
+ }
+ if config.get_text_config() is not None and config.get_text_config().base_model_tp_plan is not None:
+ config.get_text_config().base_model_tp_plan.update(additional_plan)
+
+ return config
+
@property
def is_trainable(self):
return True
diff --git a/src/transformers/quantizers/quantizer_fbgemm_fp8.py b/src/transformers/quantizers/quantizer_fbgemm_fp8.py
index 69c55d01bd..f0fa8f063b 100644
--- a/src/transformers/quantizers/quantizer_fbgemm_fp8.py
+++ b/src/transformers/quantizers/quantizer_fbgemm_fp8.py
@@ -116,7 +116,7 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
state_dict: Dict[str, Any],
**kwargs,
):
- from ..integrations import FbgemmFp8Linear
+ from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
module, tensor_name = get_module_from_name(model, param_name)
@@ -129,6 +129,13 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
+ if isinstance(module, FbgemmFp8Llama4TextExperts):
+ if self.pre_quantized or tensor_name == "bias":
+ return False
+ else:
+ if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
+ raise ValueError("Expect unquantized weights but got a quantized weight_scale")
+ return True
return False
def create_quantized_param(
@@ -143,12 +150,52 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
"""
Quantizes weights into weight and weight_scale
"""
- new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
+
+ from ..integrations import FbgemmFp8Llama4TextExperts
module, tensor_name = get_module_from_name(model, param_name)
- module._buffers[tensor_name] = new_value.to(target_device)
- # to have the right output shape -> (out_features, 1)
- module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device)
+ if isinstance(module, FbgemmFp8Llama4TextExperts):
+ if tensor_name == "gate_up_proj":
+ # Process each expert separately
+ # Transpose the second and third dimension
+ transposed_param = param_value.transpose(1, 2)
+
+ # Reshape to 2D for quantization
+ original_shape = transposed_param.shape
+ flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+ # Quantize using per row instead of per column
+ new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+
+ # Reshape back to original dimensions
+ new_value = new_value_flat.reshape(original_shape)
+ new_value = new_value.transpose(1, 2)
+ weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
+ elif tensor_name == "down_proj":
+ # Process each expert separately
+ # Transpose the weights for proper quantization
+ transposed_param = param_value.transpose(1, 2)
+
+ # Reshape to 2D for quantization
+ original_shape = transposed_param.shape
+ flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+ # Quantize using per column
+ new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+
+ # Reshape back to original dimensions
+ new_value = new_value_flat.reshape(original_shape)
+ new_value = new_value.transpose(1, 2)
+ weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
+
+ module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(weight_scale.to(target_device))
+ else:
+ new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
+ module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(
+ weight_scale.view(weight_scale.shape[0], 1).to(target_device)
+ )
+
+ module._parameters[tensor_name] = torch.nn.Parameter(new_value.to(target_device))
if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name)
@@ -165,25 +212,29 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
):
from ..integrations import replace_with_fbgemm_fp8_linear
+ tp_plan = model._tp_plan
self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
)
+ config = model.config
model = replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
pre_quantized=self.pre_quantized,
+ config=config,
+ tp_plan=tp_plan,
)
model.config.quantization_config = self.quantization_config
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
- from ..integrations import FbgemmFp8Linear
+ from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
not_missing_keys = []
for name, module in model.named_modules():
- if isinstance(module, FbgemmFp8Linear):
+ if isinstance(module, FbgemmFp8Linear) or isinstance(module, FbgemmFp8Llama4TextExperts):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index cb1169fe02..8e047569e7 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -3950,7 +3950,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
verbose (`bool`): Whether or not to print more information and warnings.
"""
- if max_length is None and len(ids) > self.model_max_length and verbose:
+ if max_length is None and len(ids) > self.model_max_length and verbose and self.model_max_length != 0:
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index be964d8490..b03455c89e 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -5823,6 +5823,41 @@ class LlamaPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
+class Llama4ForCausalLM(metaclass=DummyObject):
+ # Placeholder class: constructing it calls `requires_backends(self, ["torch"])`.
+ # NOTE(review): presumably that call raises when torch is missing — confirm in `requires_backends`.
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Llama4ForConditionalGeneration(metaclass=DummyObject):
+ # Placeholder class: constructing it calls `requires_backends(self, ["torch"])`.
+ # NOTE(review): presumably that call raises when torch is missing — confirm in `requires_backends`.
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Llama4PreTrainedModel(metaclass=DummyObject):
+ # Placeholder class: constructing it calls `requires_backends(self, ["torch"])`.
+ # NOTE(review): presumably that call raises when torch is missing — confirm in `requires_backends`.
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Llama4TextModel(metaclass=DummyObject):
+ # Placeholder class: constructing it calls `requires_backends(self, ["torch"])`.
+ # NOTE(review): presumably that call raises when torch is missing — confirm in `requires_backends`.
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Llama4VisionModel(metaclass=DummyObject):
+ # Placeholder class: constructing it calls `requires_backends(self, ["torch"])`.
+ # NOTE(review): presumably that call raises when torch is missing — confirm in `requires_backends`.
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class LlavaForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py
index 50314fc55e..c59c9b4bdd 100644
--- a/src/transformers/utils/dummy_torchvision_objects.py
+++ b/src/transformers/utils/dummy_torchvision_objects.py
@@ -72,6 +72,13 @@ class GotOcr2ImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])
+class Llama4ImageProcessorFast(metaclass=DummyObject):
+ _backends = ["torchvision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torchvision"])
+
+
class LlavaImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py
index ffd15d64ac..4e7293d50a 100644
--- a/src/transformers/utils/dummy_vision_objects.py
+++ b/src/transformers/utils/dummy_vision_objects.py
@@ -408,6 +408,13 @@ class LevitImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
+class Llama4ImageProcessor(metaclass=DummyObject):
+ _backends = ["vision"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["vision"])
+
+
class LlavaImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
diff --git a/tests/models/llama4/__init__.py b/tests/models/llama4/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/models/llama4/test_image_processing_llama4.py b/tests/models/llama4/test_image_processing_llama4.py
new file mode 100644
index 0000000000..bf84b3550d
--- /dev/null
+++ b/tests/models/llama4/test_image_processing_llama4.py
@@ -0,0 +1,128 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
+
+from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+ pass
+
+if is_vision_available() and is_torchvision_available():
+ from transformers import Llama4ImageProcessorFast
+
+
+class Llama4ImageProcessingTester(unittest.TestCase):
+ def __init__(
+ self,
+ parent,
+ batch_size=7,
+ num_channels=3,
+ image_size=18,
+ min_resolution=30,
+ max_resolution=400,
+ max_patches=1,
+ do_resize=True,
+ size=None,
+ do_normalize=True,
+ do_pad=False,
+ image_mean=[0.5, 0.5, 0.5],
+ image_std=[0.5, 0.5, 0.5],
+ do_convert_rgb=True,
+ ):
+ super().__init__()
+ size = size if size is not None else {"height": 20, "width": 20}
+ self.parent = parent
+ self.batch_size = batch_size
+ self.num_channels = num_channels
+ self.image_size = image_size
+ self.min_resolution = min_resolution
+ self.max_resolution = max_resolution
+ self.max_patches = max_patches
+ self.do_resize = do_resize
+ self.size = size
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.do_pad = do_pad
+ self.do_convert_rgb = do_convert_rgb
+
+ def prepare_image_processor_dict(self):
+ return {
+ "max_patches": self.max_patches,
+ "do_resize": self.do_resize,
+ "size": self.size,
+ "do_normalize": self.do_normalize,
+ "image_mean": self.image_mean,
+ "image_std": self.image_std,
+ "do_convert_rgb": self.do_convert_rgb,
+ "do_pad": self.do_pad,
+ }
+
+ def expected_output_image_shape(self, images):
+ return self.num_channels, self.size["height"], self.size["width"]
+
+ def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
+ return prepare_image_inputs(
+ batch_size=self.batch_size,
+ num_channels=self.num_channels,
+ min_resolution=self.min_resolution,
+ max_resolution=self.max_resolution,
+ equal_resolution=equal_resolution,
+ numpify=numpify,
+ torchify=torchify,
+ )
+
+
+@require_torch
+@require_vision
+class Llama4ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
+ test_slow_image_processor = False
+ fast_image_processing_class = Llama4ImageProcessorFast if is_torchvision_available() else None
+
+ def setUp(self):
+ super().setUp()
+ self.image_processor_tester = Llama4ImageProcessingTester(self)
+
+ @property
+ def image_processor_dict(self):
+ return self.image_processor_tester.prepare_image_processor_dict()
+
+ def test_image_processor_properties(self):
+ for image_processing_class in self.image_processor_list:
+ image_processor = image_processing_class(**self.image_processor_dict)
+ self.assertTrue(hasattr(image_processor, "do_resize"))
+ self.assertTrue(hasattr(image_processor, "size"))
+ self.assertTrue(hasattr(image_processor, "do_normalize"))
+ self.assertTrue(hasattr(image_processor, "image_mean"))
+ self.assertTrue(hasattr(image_processor, "image_std"))
+ self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
+
+ def test_split_tiles(self):
+ for image_processing_class in self.image_processor_list:
+ image_processor = image_processing_class(**self.image_processor_dict)
+ image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)[0]
+ processed_images = image_processor(
+ image,
+ max_patches=16,
+ )
+ self.assertEqual(len(processed_images.pixel_values), 1)
+ self.assertEqual(processed_images.pixel_values[0].shape[0], 17)
+ self.assertEqual(processed_images.pixel_values[0].shape[-2:], (20, 20))
diff --git a/tests/models/llama4/test_modeling_llama4.py b/tests/models/llama4/test_modeling_llama4.py
new file mode 100644
index 0000000000..65672993a0
--- /dev/null
+++ b/tests/models/llama4/test_modeling_llama4.py
@@ -0,0 +1,121 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Llama4 model."""
+
+import unittest
+
+from transformers import is_torch_available
+from transformers.testing_utils import (
+ require_read_token,
+ require_torch_large_gpu,
+ slow,
+ torch_device,
+)
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ Llama4ForConditionalGeneration,
+ Llama4Processor,
+ )
+
+
+@slow
+@require_torch_large_gpu
+@require_read_token
+class Llama4IntegrationTest(unittest.TestCase):
+ model_id = "ll-re/Llama-4-17B-Omni-Instruct"
+ # This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
+ # Depending on the hardware we get different logits / generations
+ cuda_compute_capability_major_version = None
+
+ @classmethod
+ def setUpClass(cls):
+ if is_torch_available() and torch.cuda.is_available():
+ # 8 is for A100 / A10 and 7 for T4
+ cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
+ cls.model = Llama4ForConditionalGeneration.from_pretrained(
+ "ll-re/Llama-4-17B-Omni-Instruct", device_map="auto", torch_dtype=torch.float32
+ )
+
+ def setUp(self):
+ self.processor = Llama4Processor.from_pretrained("ll-re/Llama-4-17B-Omni-Instruct", padding_side="left")
+
+ url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
+ self.messages = [
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": url},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+ ]
+
+ def test_model_17b_16e_fp16(self):
+ EXPECTED_TEXT = [
+ "The capital of France is Paris, which is located in the north-central part of the country. Paris is known for its iconic landmarks such as the",
+ "Roses are red, violets are blue, and this poem is about you. Roses are red, violets are blue, and I love",
+ ]
+
+ messages = [
+ {"role": "user", "content": "Who are you?"},
+ ]
+ inputs = self.processor.apply_chat_template(
+ messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
+ ).to(torch_device)
+
+ output = self.model.generate(**inputs, max_new_tokens=100)
+ output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+ print(output_text)
+ self.assertEqual(output_text, EXPECTED_TEXT)
+
+ def test_model_17b_16e_batch(self):
+ messages_2 = [
+ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
+ },
+ {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+ {"type": "text", "text": "Are these images identical?"},
+ ],
+ },
+ ]
+
+ inputs = self.processor.apply_chat_template(
+ [self.messages, messages_2],
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding=True,
+ add_generation_prompt=True,
+ ).to(torch_device)
+
+ output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
+ output_text = self.processor.batch_decode(output, skip_special_tokens=True)
+
+ EXPECTED_TEXTS = [
+ 'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
+ "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
+ ] # fmt: skip
+ self.assertEqual(output_text, EXPECTED_TEXTS)
diff --git a/tests/models/llama4/test_processor_llama4.py b/tests/models/llama4/test_processor_llama4.py
new file mode 100644
index 0000000000..4ec01fa497
--- /dev/null
+++ b/tests/models/llama4/test_processor_llama4.py
@@ -0,0 +1,65 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+import tempfile
+import unittest
+from typing import Optional
+
+from transformers import AutoProcessor, Llama4Processor, PreTrainedTokenizerFast
+from transformers.testing_utils import require_vision
+from transformers.utils import is_vision_available
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+if is_vision_available():
+ from transformers import Llama4ImageProcessorFast
+
+
+@require_vision
+class Llama4ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+ processor_class = Llama4Processor
+
+ def setUp(self):
+ self.tmpdirname = tempfile.mkdtemp()
+
+ image_processor = Llama4ImageProcessorFast(max_patches=1, size={"height": 20, "width": 20})
+ tokenizer = PreTrainedTokenizerFast.from_pretrained("unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit")
+ processor_kwargs = self.prepare_processor_dict()
+ processor = Llama4Processor(image_processor, tokenizer, **processor_kwargs)
+ processor.save_pretrained(self.tmpdirname)
+
+ def get_tokenizer(self, **kwargs):
+ return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+ def get_image_processor(self, **kwargs):
+ return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+ def tearDown(self):
+ shutil.rmtree(self.tmpdirname)
+
+ # Override as Llama4Processor needs image tokens in prompts
+ def prepare_text_inputs(self, batch_size: Optional[int] = None):
+ if batch_size is None:
+ return "lower newer "
+
+ if batch_size < 1:
+ raise ValueError("batch_size must be greater than 0")
+
+ if batch_size == 1:
+ return ["lower newer "]
+ return ["lower newer ", " upper older longer string"] + [" lower newer"] * (
+ batch_size - 2
+ )
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index 7ce1bdd0f8..76fc2cc428 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -236,6 +236,16 @@ SPECIAL_CASES_TO_ALLOW = {
"text_config",
"vision_config",
],
+ "Llama4Config": ["boi_token_index", "eoi_token_index"],
+ "Llama4TextConfig": [
+ "interleave_moe_layer_step",
+ "no_rope_layer_interval",
+ "no_rope_layers",
+ "output_router_logits",
+ "router_aux_loss_coef",
+ "router_jitter_noise",
+ ],
+ "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
}
@@ -358,6 +368,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"rope_theta",
"partial_rotary_factor",
"pretraining_tp",
+ "boi_token_index",
+ "eoi_token_index",
]
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index bdcec87c2b..b9f6645638 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -67,6 +67,7 @@ _re_parse_description = re.compile(r"\*optional\*, defaults to (.*)$")
# docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the
# line before the docstring.
OBJECTS_TO_IGNORE = [
+ "Llama4Processor",
# Deprecated
"InputExample",
"InputFeatures",
diff --git a/utils/check_dummies.py b/utils/check_dummies.py
index 73d7ebbfd1..48a6b2fa71 100644
--- a/utils/check_dummies.py
+++ b/utils/check_dummies.py
@@ -223,13 +223,20 @@ def check_dummies(overwrite: bool = False):
f.write(dummy_files[backend])
else:
# Temporary fix to help people identify which objects introduced are not correctly protected.
+ found = False
for _actual, _dummy in zip(
actual_dummies["torch"].split("class"), dummy_files["torch"].split("class")
):
if _actual != _dummy:
actual_broken = _actual
dummy_broken = _dummy
+ found = True
break
+
+ if not found:
+ print("A transient error was found with the dummies, please investigate.")
+ continue
+
raise ValueError(
"The main __init__ has objects that are not present in "
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py.\n"
diff --git a/utils/check_repo.py b/utils/check_repo.py
index b4119bb7b6..42beda83e6 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -144,6 +144,8 @@ IGNORE_NON_TESTED = (
"Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration.
"MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests
"MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests
+ "Llama4TextModel", # Building part of bigger (tested) model. # TODO: add tests
+ "Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Emu3VQVAE", # Building part of bigger (tested) model
"Emu3TextModel", # Building part of bigger (tested) model
]
@@ -170,6 +172,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
"models/decision_transformer/test_modeling_decision_transformer.py",
"models/bark/test_modeling_bark.py",
"models/shieldgemma2/test_modeling_shieldgemma2.py",
+ "models/llama4/test_modeling_llama4.py",
]
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and