diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ed306fac94..6c4b7498b3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -507,6 +507,8 @@ title: Llama2 - local: model_doc/llama3 title: Llama3 + - local: model_doc/llama4 + title: Llama4 - local: model_doc/longformer title: Longformer - local: model_doc/longt5 diff --git a/docs/source/en/model_doc/llama4.md b/docs/source/en/model_doc/llama4.md new file mode 100644 index 0000000000..8e2cd3a278 --- /dev/null +++ b/docs/source/en/model_doc/llama4.md @@ -0,0 +1,442 @@ + + +# Llama4 + + +
+
+ PyTorch + FlashAttention +
+
+ +Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture. +This generation includes two models: +- The highly capable Llama 4 Maverick with 17B active parameters out of ~400B total, with 128 experts. +- The efficient Llama 4 Scout also has 17B active parameters out of ~109B total, using just 16 experts. + +Both models leverage early fusion for native multimodality, enabling them to process text and image inputs. +Maverick and Scout are both trained on up to 40 trillion tokens on data encompassing 200 languages +(with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi). + +For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via +on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats. +These models are released under the custom Llama 4 Community License Agreement, available on the model repositories. + +You can find all the original Llama checkpoints under the [meta-llama](https://huggingface.co/meta-llama) organization. + +> [!TIP] +> The Llama 4 family of models comes in two flavors: 109B and 402B parameters. Both of these flavors are extremely +> large and won't fit on your run-of-the-mill device. See below for some examples to reduce the memory usage of the +> model. +> +> For the download to be faster and more resilient, we recommend installing the `hf_xet` dependency as follows: +> `pip install transformers[hf_xet]` + +The examples below demonstrate how to generate with [`Pipeline`] or the [`AutoModel`]. We additionally add an example +showcasing how to toggle the right attributes to enable very long-context generations, as some flavors of Llama 4 +have context lengths going up to 10 million tokens.
+ + + + + +```py +from transformers import pipeline +import torch + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + +messages = [ + {"role": "user", "content": "what is the recipe of mayonnaise?"}, +] + +pipe = pipeline( + "text-generation", + model=model_id, + device_map="auto", + torch_dtype=torch.bfloat16 +) + +output = pipe(messages, do_sample=False, max_new_tokens=200) +print(output[0]["generated_text"][-1]["content"]) +``` + + + + +```py +from transformers import AutoTokenizer, Llama4ForConditionalGeneration +import torch + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + +tokenizer = AutoTokenizer.from_pretrained(model_id) + +messages = [ + {"role": "user", "content": "Who are you?"}, +] +inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True) + +model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16 +) + +outputs = model.generate(**inputs.to(model.device), max_new_tokens=100) +outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:]) +print(outputs[0]) +``` + + + + +```py +from transformers import AutoProcessor, Llama4ForConditionalGeneration +import torch + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + +processor = AutoProcessor.from_pretrained(model_id) +model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, +) + +img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": img_url}, + {"type": "text", "text": "Describe this image in two sentences."}, + ] + }, +] + +inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", +).to(model.device) + +outputs = model.generate( + 
**inputs, + max_new_tokens=256, +) + +response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0] +print(response) +``` + + + + +```py +from transformers import AutoProcessor, Llama4ForConditionalGeneration +import torch + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + +processor = AutoProcessor.from_pretrained(model_id) +model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, +) + +url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" +url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": url1}, + {"type": "image", "url": url2}, + {"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"}, + ] + }, +] + +inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", +).to(model.device) + +outputs = model.generate( + **inputs, + max_new_tokens=256, +) + +response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0] +print(response) +``` + + + + +Beware: the example below uses both `device_map="auto"` and flex-attention. +Please use `torchrun` to run this example in tensor-parallel mode. + +We will work to enable running with `device_map="auto"` and flex-attention without +tensor-parallel in the future. 
+ +```py +from transformers import Llama4ForConditionalGeneration, AutoTokenizer +import torch +import time + +file = "very_long_context_prompt.txt" +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + +with open(file, "r") as f: + very_long_text = "\n".join(f.readlines()) + +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + device_map="auto", + attn_implementation="flex_attention", + torch_dtype=torch.bfloat16 +) + +messages = [ + {"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."}, +] +input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt") + +torch.cuda.synchronize() +start = time.time() +out = model.generate( + input_ids.to(model.device), + prefill_chunk_size=2048*8, + max_new_tokens=300, + cache_implementation="hybrid", +) +print(time.time()-start) +print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:])) +print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB") +``` + + + + +## Efficiency; how to get the best out of llama 4 + +### The Attention methods + +Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the [Attention Interface](../attention_interface) overview for an in-depth explanation of our interface. + +As of release, the Llama 4 model supports the following attention methods: `eager`, `flex_attention`, `sdpa`. We recommend using `flex_attention` for best results. +Switching attention mechanism is done at the model initialization step: + + + + + +Setting Flex Attention ensures the best results with the very long context the model can handle. + +> [!TIP] Beware: the example below uses both `device_map="auto"` and flex-attention. +> Please use `torchrun` to run this example in tensor-parallel mode. 
+> +> We will work to enable running with `device_map="auto"` and flex-attention without +> tensor-parallel in the future. + +```py +from transformers import Llama4ForConditionalGeneration +import torch + +model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + attn_implementation="flex_attention", + device_map="auto", + torch_dtype=torch.bfloat16, +) +``` + + +The `sdpa` attention method is generally more compute-efficient than the `eager` method. + +```py +from transformers import Llama4ForConditionalGeneration +import torch + +model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + attn_implementation="sdpa", + device_map="auto", + torch_dtype=torch.bfloat16, +) +``` + + +The `eager` attention method is set by default, so no need for anything different when loading the model: + +```py +from transformers import Llama4ForConditionalGeneration +import torch + +model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, +) +``` + + + + +### Quantization + +Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for available quantization backends. +At time of release, both FBGEMM and LLM-Compressor are supported; more quantization methods will be supported in the days that follow the release. 
+ +See below for examples using both: + + + +Here is an example loading a BF16 model in FP8 using the FBGEMM approach: + + + + +```python +from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config +import torch + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" + +tokenizer = AutoTokenizer.from_pretrained(model_id) + +messages = [ + {"role": "user", "content": "Who are you?"}, +] +inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True) + +model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, + quantization_config=FbgemmFp8Config() +) + +outputs = model.generate(**inputs.to(model.device), max_new_tokens=100) +outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:]) +print(outputs[0]) +``` + + + + +To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release: + +```python +from transformers import AutoTokenizer, Llama4ForConditionalGeneration +import torch + +model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8" + +tokenizer = AutoTokenizer.from_pretrained(model_id) + +messages = [ + {"role": "user", "content": "Who are you?"}, +] +inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True) + +model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + tp_plan="auto", + torch_dtype=torch.bfloat16, +) + +outputs = model.generate(**inputs.to(model.device), max_new_tokens=100) +outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:]) +print(outputs[0]) +``` + + + +### Offloading + +Enabling CPU-offloading means that components of the model might be moved to CPU instead of GPU in case the GPU-memory available isn't sufficient to load the entire model. +At inference, different components will be loaded/unloaded from/to the GPU on the fly.
This ensures that the model can be loaded on smaller machines as long as the CPU-memory is sufficient. +However, this also slows down inference as it adds communication overhead. + +In order to enable CPU-offloading, you simply need to specify the `device_map` to `auto` at model load: + +```py +from transformers import Llama4ForConditionalGeneration +import torch + +model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, +) +``` + +## Llama4Config + +[[autodoc]] Llama4Config + +## Llama4TextConfig + +[[autodoc]] Llama4TextConfig + +## Llama4VisionConfig + +[[autodoc]] Llama4VisionConfig + +## Llama4Processor + +[[autodoc]] Llama4Processor + +## Llama4ImageProcessorFast + +[[autodoc]] Llama4ImageProcessorFast + +## Llama4ForConditionalGeneration + +[[autodoc]] Llama4ForConditionalGeneration +- forward + +## Llama4ForCausalLM + +[[autodoc]] Llama4ForCausalLM +- forward + +## Llama4TextModel + +[[autodoc]] Llama4TextModel +- forward + +## Llama4ForCausalLM + +[[autodoc]] Llama4ForCausalLM +- forward + +## Llama4VisionModel + +[[autodoc]] Llama4VisionModel +- forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 96f993e7e5..3b525f7f15 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -562,6 +562,12 @@ _import_structure = { "models.levit": ["LevitConfig"], "models.lilt": ["LiltConfig"], "models.llama": ["LlamaConfig"], + "models.llama4": [ + "Llama4Config", + "Llama4Processor", + "Llama4TextConfig", + "Llama4VisionConfig", + ], "models.llava": [ "LlavaConfig", "LlavaProcessor", @@ -1354,6 +1360,7 @@ else: _import_structure["models.detr"].append("DetrImageProcessorFast") _import_structure["models.gemma3"].append("Gemma3ImageProcessorFast") _import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast") + _import_structure["models.llama4"].append("Llama4ImageProcessorFast") 
_import_structure["models.llava"].append("LlavaImageProcessorFast") _import_structure["models.llava_next"].append("LlavaNextImageProcessorFast") _import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast") @@ -2510,6 +2517,15 @@ else: "GlmPreTrainedModel", ] ) + _import_structure["models.llama4"].extend( + [ + "Llama4ForCausalLM", + "Llama4ForConditionalGeneration", + "Llama4TextModel", + "Llama4VisionModel", + "Llama4PreTrainedModel", + ] + ) _import_structure["models.glpn"].extend( [ "GLPNForDepthEstimation", @@ -5807,6 +5823,12 @@ if TYPE_CHECKING: from .models.levit import LevitConfig from .models.lilt import LiltConfig from .models.llama import LlamaConfig + from .models.llama4 import ( + Llama4Config, + Llama4Processor, + Llama4TextConfig, + Llama4VisionConfig, + ) from .models.llava import ( LlavaConfig, LlavaProcessor, @@ -6646,6 +6668,7 @@ if TYPE_CHECKING: from .models.detr import DetrImageProcessorFast from .models.gemma3 import Gemma3ImageProcessorFast from .models.got_ocr2 import GotOcr2ImageProcessorFast + from .models.llama4 import Llama4ImageProcessorFast from .models.llava import LlavaImageProcessorFast from .models.llava_next import LlavaNextImageProcessorFast from .models.llava_onevision import LlavaOnevisionImageProcessorFast @@ -7827,6 +7850,13 @@ if TYPE_CHECKING: LlamaModel, LlamaPreTrainedModel, ) + from .models.llama4 import ( + Llama4ForCausalLM, + Llama4ForConditionalGeneration, + Llama4PreTrainedModel, + Llama4TextModel, + Llama4VisionModel, + ) from .models.llava import ( LlavaForConditionalGeneration, LlavaPreTrainedModel, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 3f2e02a703..cd2a711960 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1811,6 +1811,200 @@ class HybridCache(Cache): self.value_cache[layer_idx].zero_() +class HybridChunkedCache(Cache): + """ + Hybrid Cache class to be used with `torch.compile` for Gemma2 models that 
alternate between a local sliding window attention + and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention + and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponent cache class. + + Parameters: + config (`PretrainedConfig`): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a + smaller batch size is used. + max_cache_len (`int`, *optional*): + The maximum sequence length with which the model will be used. + device (`torch.device` or `str`, *optional*): + The device on which the cache should be initialized. If you're using more than 1 computation device, you + should pass the `layer_device_map` argument instead. + dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): + The default `dtype` to use when initializing the layer. + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]`, *optional*): + Mapping between the layers and their devices. This is required when you are manually initializing the cache + and the model is split between different GPUs. You can know which layers are mapped to which device by + checking the associated device_map: `model.hf_device_map`.
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache + + >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") + + >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate + >>> max_generated_length = inputs.input_ids.shape[1] + 10 + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + HybridCache() + ``` + """ + + # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert + # ALL changes from the PR that commented the line below when reactivating it. 
+ # is_compileable = True + + def __init__( + self, + config: PretrainedConfig, + max_batch_size: int, + max_cache_len: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: torch.dtype = torch.bfloat16, + layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, + ) -> None: + super().__init__() + if not hasattr(config, "sliding_window") or config.sliding_window is None: + self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192) + else: + self.sliding_window = config.sliding_window + self.max_cache_len = max_cache_len + self.max_batch_size = max_batch_size + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self._dtype = dtype + + if hasattr(config.get_text_config(), "no_rope_layers"): + self.is_sliding = config.no_rope_layers + else: + layer_switch = getattr(config, "sliding_window_pattern", 2) + self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)] + + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.cumulative_length = [0 for _ in range(config.num_hidden_layers)] + + def initialise_cache_layer(self, layer_idx, key_states): + if len(self.key_cache) > layer_idx: + return + + num_key_value_heads = key_states.shape[1] + device = key_states.device + global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim) + sliding_cache_shape = ( + self.max_batch_size, + num_key_value_heads, + self.sliding_window, + self.head_dim, + ) + # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph + # breaks when updating the cache. 
+ cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape + new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) + + def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + cumulative_length = self.cumulative_length[layer_idx] + is_full = cumulative_length >= max_cache_len + if is_full: + full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) + full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) + elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: + full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2) + full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2) + else: + self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) + self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) + self.cumulative_length[layer_idx] += key_states.shape[-2] + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) + self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) + self.cumulative_length[layer_idx] += key_states.shape[-2] + # we should return the whole states instead of k_out, v_out to take the whole prompt + # into consideration when building kv cache instead of just throwing away tokens outside of the window + return full_key_states, full_value_states + + def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len): + k_out[:, :, cache_position] = key_states + 
v_out[:, :, cache_position] = value_states + + self.key_cache[layer_idx] = k_out + self.value_cache[layer_idx] = v_out + return k_out, v_out + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if cache_kwargs is None: + cache_kwargs = {} + cache_position = cache_kwargs.get("cache_position") + self.initialise_cache_layer(layer_idx, key_states) + + # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used + # when the cache is initialized in the forward pass (e.g. Gemma2) + if self.key_cache[layer_idx].device != key_states.device: + self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) + if self.value_cache[layer_idx].device != value_states.device: + self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) + + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] + key_states = key_states.to(k_out.dtype) + value_states = value_states.to(v_out.dtype) + + if self.is_sliding[layer_idx]: + update_fn = self._sliding_update + else: + update_fn = self._static_update + + return update_fn( + cache_position, + layer_idx, + key_states, + value_states, + k_out, + v_out, + k_out.shape[2], + ) + + def get_max_cache_shape(self) -> Optional[int]: + return self.max_cache_len + + def get_seq_length(self, layer_idx: Optional[int] = 0): + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # limit the check to the first batch member and head dimension. + # TODO: deprecate this function in favor of `cache_position` + if layer_idx != 0: + raise ValueError( + "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " + "Using the `layer_idx` argument is not supported." 
+ ) + if len(self.key_cache) == 0: + return 0 + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def reset(self): + """Resets the cache values while preserving the objects""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() + self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] + + class MambaCache: """ Cache for mamba model which does not have attention mechanism and key value states. diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 808cfbcf93..d5991cae8f 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -801,18 +801,19 @@ class PretrainedConfig(PushToHubMixin): def to_diff_dict(self) -> dict[str, Any]: """ - Removes all attributes from config which correspond to the default config attributes for better readability and - serializes to a Python dictionary. + Removes all attributes from the configuration that correspond to the default config attributes for + better readability, while always retaining the `config` attribute from the class. Serializes to a + Python dictionary. Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance. 
""" config_dict = self.to_dict() - # get the default config dict + # Get the default config dict (from a fresh PreTrainedConfig instance) default_config_dict = PretrainedConfig().to_dict() - # get class specific config dict + # Get class-specific config dict if not part of a composition class_config_dict = self.__class__().to_dict() if not self.is_composition else {} serializable_config_dict = {} @@ -847,8 +848,7 @@ class PretrainedConfig(PushToHubMixin): if not isinstance(self.quantization_config, dict) else self.quantization_config ) - - # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + # Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. _ = serializable_config_dict.pop("_pre_quantization_dtype", None) self.dict_torch_dtype_to_str(serializable_config_dict) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index a6b0a72162..c743480d78 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -52,6 +52,7 @@ if is_torch_available(): from ..cache_utils import ( HQQQuantizedCache, HybridCache, + HybridChunkedCache, MambaCache, OffloadedStaticCache, QuantizedCacheConfig, @@ -69,6 +70,7 @@ if is_torch_available(): "offloaded_static": OffloadedStaticCache, "sliding_window": SlidingWindowCache, "hybrid": HybridCache, + "hybrid_chunked": HybridChunkedCache, "mamba": MambaCache, } QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} @@ -416,6 +418,7 @@ class GenerationConfig(PushToHubMixin): if isinstance(self.cache_config, dict): self.cache_config = cache_config_class.from_dict(self.cache_config) self.return_legacy_cache = kwargs.pop("return_legacy_cache", None) + self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None) # Parameters for manipulation of the model output logits self.temperature = kwargs.pop("temperature", 1.0) diff --git 
a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 232cceeedf..bc00e29ba5 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1830,6 +1830,9 @@ class GenerationMixin: Returns the resulting cache object. """ + if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""): + cache_implementation = "hybrid_chunked" + cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] requires_cross_attention_cache = ( self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None @@ -3405,7 +3408,12 @@ class GenerationMixin: os.environ["TOKENIZERS_PARALLELISM"] = "0" model_forward = self.get_compiled_call(generation_config.compile_config) - is_prefill = True + if generation_config.prefill_chunk_size is not None: + model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs) + is_prefill = False + else: + is_prefill = True + while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -4855,6 +4863,45 @@ class GenerationMixin: else: return input_ids + def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs): + # Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may + # end up needing just a bit more graphs than the default (which is 8). 
Doing this avoids very cryptic warnings + torch._dynamo.config.cache_size_limit = 64 + + chunk_size = generation_config.prefill_chunk_size + # Only chunk up the token just before last, so that decoding is completely performed outside this function + # (here we simply prefill the cache) + input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1) + + if "past_key_values" not in model_kwargs: + raise ValueError("Cannot use prefill chunkink without a cache") + + model_forward = self.get_compiled_call(generation_config.compile_config) + attention_mask = model_kwargs.pop("attention_mask", None) + + past_length = 0 + for input_chunk in input_chunks: + current_length = past_length + input_chunk.shape[-1] + # Prepare inputs + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask[:, :current_length] + model_kwargs["cache_position"] = torch.arange( + past_length, current_length, dtype=torch.long, device=input_chunk.device + ) + model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0) + model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs) + + outputs = model_forward(**model_inputs, return_dict=True) + + model_kwargs["past_key_values"] = outputs.past_key_values + past_length = current_length + + model_kwargs["attention_mask"] = attention_mask + model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 + _ = model_kwargs.pop("position_ids", None) + + return model_kwargs + def _speculative_sampling( candidate_input_ids, diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index da8c9cd4c6..8d03c5cf79 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -53,7 +53,7 @@ _import_structure = { "unset_hf_deepspeed_config", ], "eetq": ["replace_with_eetq_linear"], - "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"], + "fbgemm_fp8": ["FbgemmFp8Linear", 
"FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"], "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"], "fsdp": ["is_fsdp_managed_module"], "ggml": [ @@ -192,7 +192,7 @@ if TYPE_CHECKING: unset_hf_deepspeed_config, ) from .eetq import replace_with_eetq_linear - from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear + from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear from .fsdp import is_fsdp_managed_module from .ggml import ( diff --git a/src/transformers/integrations/compressed_tensors.py b/src/transformers/integrations/compressed_tensors.py new file mode 100644 index 0000000000..752227914d --- /dev/null +++ b/src/transformers/integrations/compressed_tensors.py @@ -0,0 +1,54 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from transformers.utils import is_torch_available + + +if is_torch_available(): + import torch + import torch.nn as nn + +from transformers.models.llama4.modeling_llama4 import Llama4TextMLP + + +def skip(*args, **kwargs): + pass + + +class CompressedExpertsLinear(nn.Module): + """ + A module that implements a compressed version of a list of expert modules. + This is specifically designed to work with Llama4TextExperts in MoE layers. + """ + + def __init__(self, config): + # Skip random weight initialization for experts. 
Otherwise, + # the init of this module would take over minutes. For a model + # with tens of layers of experts, it would easily take over 20 minutes. + nn.init.kaiming_uniform_ = skip + nn.init.uniform_ = skip + nn.init.normal_ = skip + super().__init__() + self.num_experts = config.num_local_experts + self.expert_modules = nn.ModuleList([Llama4TextMLP(config) for _ in range(self.num_experts)]) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1]) + expert_routed_out_list = [] + for expert_idx in range(self.num_experts): + expert_routed_out_list.append(self.expert_modules[expert_idx](hidden_states[expert_idx])) + routed_out = torch.cat(expert_routed_out_list, dim=0) + return routed_out diff --git a/src/transformers/integrations/fbgemm_fp8.py b/src/transformers/integrations/fbgemm_fp8.py index 71c2b570cc..5cca37f515 100644 --- a/src/transformers/integrations/fbgemm_fp8.py +++ b/src/transformers/integrations/fbgemm_fp8.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from ..activations import ACT2FN from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging @@ -28,36 +29,36 @@ if is_fbgemm_gpu_available(): logger = logging.get_logger(__name__) -class FbgemmFp8Linear(torch.nn.Module): +class FbgemmFp8Linear(torch.nn.Linear): def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32): - super().__init__() + super().__init__(in_features, out_features, bias) self.in_features = in_features self.out_features = out_features - self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn)) - self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype)) + self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn)) + self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype)) self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False) if bias: - self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype)) + self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype)) else: self.bias = None def forward(self, x): - num_tokens = None # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here output_shape = (*x.shape[:-1], -1) # x_quantized and x_scale are not necessarily on the same device as x, this is an issue. # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45 x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( - x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub + x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub ) # moving x_quantized, x_scale here creates glibberish output ... 
class FbgemmFp8Llama4TextExperts(nn.Module):
    """
    FP8 replacement for the Llama4 text experts, computed with the fbgemm row-wise FP8 kernels.

    The stacked expert weights are stored as `torch.float8_e4m3fn` parameters next to `torch.float32`
    scale parameters. Activations are quantized per row at runtime.

    Args:
        config: configuration providing `num_local_experts`, `intermediate_size`, `hidden_size` and `hidden_act`.
        dtype: currently unused; the scales are always allocated in `torch.float32`.
    """

    def __init__(self, config, dtype=torch.float32):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        self.act_fn = ACT2FN[config.hidden_act]
        # FP8 weight and scale of the packed gate/up projection
        self.gate_up_proj = torch.nn.Parameter(
            torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn)
        )
        self.gate_up_proj_scale = torch.nn.Parameter(
            torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32)
        )
        # FP8 weight and scale of the down projection
        self.down_proj = torch.nn.Parameter(
            torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn)
        )
        self.down_proj_scale = torch.nn.Parameter(
            torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32)
        )
        # Upper bound used when quantizing the activations (non persistent, set after module replacement)
        self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
        Returns:
            torch.Tensor: (batch_size * token_num, hidden_size)
        """
        # One slice of tokens per expert
        hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)

        # Pre-allocate tensor for all expert outputs with same shape (and device) as hidden_states
        next_states = torch.empty_like(hidden_states)

        # Loop-invariant values. The gate/up split is derived from the actual parameter shape rather
        # than `expert_dim` (presumably because the last dim can be sharded - confirm against the TP plan).
        sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
        gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
        down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)

        for i in range(self.num_experts):
            # Quantize the activations routed to this expert (keyword form, as in `FbgemmFp8Linear`)
            expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
                hidden_states[i], scale_ub=self.input_scale_ub
            )

            # (hidden_size, 2 * dim) -> (2 * dim, hidden_size): first half of the rows is gate, second half is up
            gate_up_weight = self.gate_up_proj[i].transpose(0, 1)
            gate_up_scale = gate_up_proj_scale_float32[i][0]

            gate = torch.ops.fbgemm.f8f8bf16_rowwise(
                expert_quantized,
                gate_up_weight[:sharded_expert_dim].contiguous(),
                expert_scale,
                gate_up_scale[:sharded_expert_dim].view(-1, 1).contiguous(),
                use_fast_accum=True,
            )

            up = torch.ops.fbgemm.f8f8bf16_rowwise(
                expert_quantized,
                gate_up_weight[sharded_expert_dim:].contiguous(),
                expert_scale,
                gate_up_scale[sharded_expert_dim:].view(-1, 1).contiguous(),
                use_fast_accum=True,
            )

            activated = up * self.act_fn(gate)

            # Re-quantize the intermediate activations before the down projection
            activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
                activated, scale_ub=self.input_scale_ub
            )

            expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
                activated_quantized,
                self.down_proj[i].transpose(0, 1).contiguous(),
                activated_scale,
                down_proj_scale_float32[i].view(-1, 1).contiguous(),
                use_fast_accum=True,
            )

            # Copies the kernel output into the pre-allocated tensor, which lives on the device of
            # `hidden_states`, so no extra device move is needed afterwards.
            next_states[i] = expert_output

        return next_states.view(-1, self.hidden_size)
_replace_with_fbgemm_fp8_linear( quantization_config=None, has_been_replaced=False, pre_quantized=False, + config=None, + tp_plan=None, ): """ Private method that wraps the recursion for module replacement. Returns the converted model and a boolean that indicates if the conversion has been successfull or not. """ + + import re + if current_key_name is None: current_key_name = [] @@ -105,9 +197,27 @@ def _replace_with_fbgemm_fp8_linear( # Force requires grad to False to avoid unexpected errors model._modules[name].requires_grad_(False) # set non persistant buffer outside of init_empty_weights + model._modules[name].input_scale_ub = torch.tensor( + [quantization_config.activation_scale_ub], + dtype=torch.float, + ) + if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert: + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(include_buffers=True): + tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[ + re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj") + ] + tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None + model._modules[name] = FbgemmFp8Llama4TextExperts( + config.text_config, + ) model._modules[name].input_scale_ub = torch.tensor( [quantization_config.activation_scale_ub], dtype=torch.float ) + if len(list(module.children())) > 0: _, has_been_replaced = _replace_with_fbgemm_fp8_linear( module, @@ -116,6 +226,8 @@ def _replace_with_fbgemm_fp8_linear( quantization_config, has_been_replaced=has_been_replaced, pre_quantized=pre_quantized, + config=config, + tp_plan=tp_plan, ) # Remove the last key for recursion current_key_name.pop(-1) @@ -123,7 +235,13 @@ def _replace_with_fbgemm_fp8_linear( def replace_with_fbgemm_fp8_linear( - model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, 
pre_quantized=False + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + pre_quantized=False, + config=None, + tp_plan=None, ): """ A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules. @@ -151,9 +269,14 @@ def replace_with_fbgemm_fp8_linear( modules_to_not_convert.extend(quantization_config.modules_to_not_convert) modules_to_not_convert = list(set(modules_to_not_convert)) model, has_been_replaced = _replace_with_fbgemm_fp8_linear( - model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized + model, + modules_to_not_convert, + current_key_name, + quantization_config, + pre_quantized=pre_quantized, + config=config, + tp_plan=tp_plan, ) - if not has_been_replaced: logger.warning( "You are loading your model using FP8 quantization but no linear modules were found in your model." diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index b0a054998c..1aa146e4a4 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -34,10 +34,7 @@ from ..utils import is_torch_flex_attn_available if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import ( - BlockMask, - flex_attention, - ) + from torch.nn.attention.flex_attention import BlockMask, flex_attention from torch.nn.attention.flex_attention import ( create_block_mask as create_block_causal_mask_flex, ) @@ -64,14 +61,23 @@ class WrappedFlexAttention: Initialize or update the singleton instance. 
""" if self._is_flex_compiled is False: - self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False) + self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor") self._is_flex_compiled = True def __call__(self): return self._compiled_flex_attention -def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask": +Offset = Union[torch.Tensor, int] + + +def make_flex_block_causal_mask( + attention_mask_2d: torch.Tensor, + attention_chunk_size: Optional[int] = None, + query_length=None, + key_length=None, + offsets: Optional[Tuple[Offset, Offset]] = None, +) -> "BlockMask": """ Create a block causal document mask for a batch of sequences, both packed and unpacked. Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. @@ -94,10 +100,13 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask": Returns: BlockMask """ + attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length)) device = attention_mask_2d.device + document_ids = attention_mask_2d.clone() - document_ids = attention_mask_2d - batch_size, total_seq_len = document_ids.shape + if attention_chunk_size is not None: + # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3] + document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size) # Instead of passing a tensor mask, flex attention requires a mask_mod function # that determines which elements of QK^T should be included in the attention @@ -112,18 +121,30 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask": See :func:`~torchtune.modules.attention_utils.create_block_causal_mask` for an illustration. 
""" - causal_mask = q_idx >= kv_idx + causal_mask = q_idx >= kv_idx # not valid when decoding document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx] - padding_mask = document_ids[batch_idx, q_idx] > 0 - return causal_mask & document_mask & padding_mask + padding_mask = attention_mask_2d[batch_idx, q_idx] > 0 + final_mask = causal_mask & padding_mask & document_mask + return final_mask + if offsets is not None: + q_offset = offsets[0] + kv_offset = offsets[1] + + def mask_mod(batch_idx, head_idx, q_idx, kv_idx): + offset_q = q_idx + q_offset + offset_kv = kv_idx + kv_offset + return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv) + else: + mask_mod = causal_mask_mod return create_block_causal_mask_flex( - mask_mod=causal_mask_mod, - B=batch_size, + mask_mod=mask_mod, + B=1, H=None, # attention head - Q_LEN=total_seq_len, - KV_LEN=total_seq_len, + Q_LEN=query_length, + KV_LEN=key_length, device=device, + _compile=True, ) @@ -144,6 +165,18 @@ def compile_friendly_flex_attention( ) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + def flex_attention_forward( module: torch.nn.Module, query: torch.Tensor, @@ -174,14 +207,25 @@ def flex_attention_forward( score = score + head_mask[batch_idx][head_idx][0][0] return score + enable_gqa = True + num_local_query_heads = query.shape[1] + + # When running TP this helps: + if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0): + key = repeat_kv(key, query.shape[1] // key.shape[1]) + value = repeat_kv(value, query.shape[1] // value.shape[1]) + enable_gqa = False + + kernel_options = kwargs.get("kernel_options", None) attn_output, attention_weights = compile_friendly_flex_attention( query, key, value, score_mod=score_mod, block_mask=block_mask, - enable_gqa=True, + enable_gqa=enable_gqa, scale=scaling, + kernel_options=kernel_options, # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. # For simplification, we thus always return it as no additional computations are introduced. 
return_lse=True, diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 890832a26e..9c924c048a 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -31,7 +31,7 @@ def sdpa_attention_forward( value = repeat_kv(value, module.num_key_value_groups) causal_mask = attention_mask - if attention_mask is not None: + if attention_mask is not None and causal_mask.ndim == 4: causal_mask = causal_mask[:, :, :, : key.shape[-2]] # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index f8a04c037b..34b21444fe 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -61,6 +61,21 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li return [single_size] * blocks +str_to_torch_dtype = { + "BOOL": torch.bool, + "U8": torch.uint8, + "I8": torch.int8, + "I16": torch.int16, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I32": torch.int32, + "F32": torch.float32, + "F64": torch.float64, + "I64": torch.int64, + "F8_E4M3": torch.float8_e4m3fn, +} + + def get_packed_weights(param, empty_param, device_mesh, rank, dim): """ When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share. 
@@ -106,6 +121,12 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim): tensors_slices += range(block_offset + start, block_offset + stop) block_offset += block_size + slice_dtype = slice_.get_dtype() + # Handle F8_E4M3 dtype by converting to float16 before slicing + # Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn' + if slice_dtype == "F8_E4M3": + slice_ = slice_[...].to(torch.float16) + if dim == 0: tensor = slice_[tensors_slices, ...] elif dim == 1 or dim == -2: @@ -114,7 +135,7 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim): tensor = slice_[..., tensors_slices] else: raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported") - return tensor + return tensor.to(str_to_torch_dtype[slice_dtype]) def get_tensor_shard(param, empty_param, device_mesh, rank, dim): @@ -199,11 +220,12 @@ class GatherParallel(TensorParallelLayer): @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): if isinstance(inputs[0], DTensor): - inputs[0] = inputs[0].to_local() + inputs = inputs[0].to_local() return inputs @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # this op cannot be asynch, otherwise it completely breaks the outputs of models torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False) return outputs @@ -266,7 +288,7 @@ class ColwiseParallel(TensorParallelLayer): # transform the input layouts to the desired layouts of ColwiseParallel if input_layouts != desired_input_layouts: - input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False) return input_tensor def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): @@ -291,7 +313,7 @@ class 
ColwiseParallel(TensorParallelLayer): def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): # outputs is a shard on last dimension DTensor, i.e. Shard(-1) if outputs.placements != output_layouts: - outputs = outputs.redistribute(placements=output_layouts, async_op=True) + outputs = outputs.redistribute(placements=output_layouts, async_op=False) # back to local tensor return outputs.to_local() if use_local_output else outputs @@ -343,16 +365,6 @@ class RowwiseParallel(TensorParallelLayer): self.use_local_output = use_local_output self.use_dtensor = use_dtensor - @staticmethod - def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): - input_tensor = inputs[0] - if not isinstance(input_tensor, DTensor): - input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) - - if input_layouts != desired_input_layouts: - input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) - return input_tensor - def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) # means Rowwise as nn.Linear is input * weight^T + bias, where @@ -371,6 +383,20 @@ class RowwiseParallel(TensorParallelLayer): parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False) return nn.Parameter(parameter) + @staticmethod + def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): + if hasattr(mod, "bias") and mod.bias is not None: + mod._bias = mod.bias + mod.bias = None + + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True) + return input_tensor + 
class SequenceParallel(TensorParallelLayer):
    """
    SequenceParallel replicates the parameters of a compatible ``nn.Module`` and runs its computation with the
    input sharded on the sequence dimension.

    This style follows the operation described in the paper
    "Reducing Activation Recomputation in Large Transformer Models".

    If the input passed to the module is a plain :class:`torch.Tensor`, it is wrapped into a replicated
    :class:`DTensor` and then redistributed so that it is sharded on the sequence dimension. The module output is
    always redistributed back to ``Replicate()`` and returned as a local :class:`torch.Tensor`.

    Keyword Args:
        sequence_dim (int, optional):
            The sequence dimension of the input tensor, used to shard the input, default: 1.
        use_local_output (bool, optional):
            Stored on the instance. Note that `_prepare_output_fn` currently always returns a local
            :class:`torch.Tensor`, whatever this value is. Default: False.
        use_dtensor (bool, optional):
            Accepted for signature parity with the other styles but currently ignored: parameters are always
            wrapped into replicated :class:`DTensor` objects.

    .. note:: Parameters are loaded in full on every rank and declared as replicated without any check
        (``run_check=False``), so they must be identical across ranks.
    """

    def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
        super().__init__()
        self.input_layouts = (Replicate(),)
        # Shard on the requested sequence dimension (identical to the previous hard-coded `Shard(1)` by default)
        self.sequence_sharding = (Shard(sequence_dim),)
        self.desired_input_layouts = self.sequence_sharding
        self.output_layouts = (Replicate(),)
        self.use_local_output = use_local_output
        # `use_dtensor` is deliberately not honored: this style always produces DTensor parameters
        self.use_dtensor = True

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        # Only the first positional input is handled
        input_tensor = inputs[0]
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
        return input_tensor

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # The output is gathered back to a replicated layout, as the next layer is not expected to be
        # sequence-sharded. `output_layouts` and `use_local_output` are not consulted: the result is
        # always converted to a local tensor.
        outputs = outputs.redistribute(placements=(Replicate(),), async_op=True)
        return outputs.to_local()

    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
        # Nothing is sharded here: the full parameter is materialized on every rank, cast to the
        # requested dtype and (optionally) wrapped into a replicated DTensor.
        parameter = param[:]
        parameter = parameter.to(param_casting_dtype)
        if to_contiguous:
            parameter = parameter.contiguous()
        if self.use_dtensor:
            parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
        return nn.Parameter(parameter)
NotImplementedError as e: + print( + f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}" + ) else: # TODO log no plan modules in set + # print("No plan for", parameter_name,end ="\n") param = param[...].to(param_casting_dtype) if is_contiguous: param = param.contiguous() diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2e0a389d74..6a3286cbc9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -484,6 +484,7 @@ str_to_torch_dtype = { "F32": torch.float32, "F64": torch.float64, "I64": torch.int64, + "F8_E4M3": torch.float8_e4m3fn, } if is_torch_greater_or_equal("2.1.0"): @@ -1914,16 +1915,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config - if self.base_model is self: - self._pp_plan = ( - self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None - ) - self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {} - else: - self._tp_plan = self._tp_plan or {} - for name, module in self.named_children(): - if plan := getattr(module, "_tp_plan", None): - self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()}) + self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None + self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {} + for name, module in self.named_children(): + if plan := getattr(module, "_tp_plan", None): + self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()}) if self._tp_plan is not None and is_torch_greater_or_equal("2.3"): for _, v in self._tp_plan.items(): @@ -4054,6 +4050,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, 
GenerationMixin, PushToHubMix import sys sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") # This is the easiest way to dispatch to the current process device device_map = tp_device # Assuming sharding the model onto the world @@ -4238,6 +4235,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ) torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) device_map = hf_quantizer.update_device_map(device_map) + config = hf_quantizer.update_tp_plan(config) # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` if hasattr(hf_quantizer.quantization_config.quant_method, "value"): @@ -4370,9 +4368,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix if hf_quantizer is not None: hf_quantizer.preprocess_model( - model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules + model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config ) - # We store the original dtype for quantized models as we cannot easily retrieve it # once the weights have been quantized # Note that once you have loaded a quantized model, you can't change its dtype so this will @@ -4901,7 +4898,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix name, casting_dtype, to_contiguous, - tp_device.index, + os.environ["RANK"], device_mesh, ) @@ -5174,6 +5171,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding (where we want the speed-ups of compiled version with static shapes).""" # Only reset it if not present or different from previous config + if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT + return self.__call__ default_config = getattr(self.generation_config, "compile_config", 
CompileConfig()) if ( not hasattr(self, "_compiled_call") diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index b2cde9f4bc..08cec64b41 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -148,6 +148,7 @@ from . import ( levit, lilt, llama, + llama4, llava, llava_next, llava_next_video, diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index c06017fd76..d21a16beb1 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -544,10 +544,6 @@ class _BaseAutoModelClass: if kwargs_orig.get("quantization_config", None) is not None: kwargs["quantization_config"] = kwargs_orig["quantization_config"] - # AutoClass-specific config manipulation - config = copy.deepcopy(config) - config = cls._prepare_config_for_auto_class(config) - has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map has_local_code = type(config) in cls._model_mapping.keys() trust_remote_code = resolve_trust_remote_code( @@ -570,6 +566,8 @@ class _BaseAutoModelClass: ) elif type(config) in cls._model_mapping.keys(): model_class = _get_model_class(config, cls._model_mapping) + if model_class.config_class == config.sub_configs.get("text_config", None): + config = config.get_text_config() return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs ) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 9937b55a8b..759b7ad3d9 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -170,6 +170,8 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("levit", "LevitConfig"), ("lilt", "LiltConfig"), ("llama", "LlamaConfig"), + ("llama4", "Llama4Config"), + ("llama4_text", "Llama4TextConfig"), ("llava", "LlavaConfig"), 
("llava_next", "LlavaNextConfig"), ("llava_next_video", "LlavaNextVideoConfig"), @@ -519,6 +521,8 @@ MODEL_NAMES_MAPPING = OrderedDict( ("llama", "LLaMA"), ("llama2", "Llama2"), ("llama3", "Llama3"), + ("llama4", "Llama4"), + ("llama4_text", "Llama4ForCausalLM"), ("llava", "LLaVa"), ("llava_next", "LLaVA-NeXT"), ("llava_next_video", "LLaVa-NeXT-Video"), @@ -776,6 +780,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict( ("rt_detr_resnet", "rt_detr"), ("granitevision", "llava_next"), ("sam_vision_model", "sam"), + ("llama4_text", "llama4"), ] ) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 7cfda72a93..2f9d42fcdb 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -104,6 +104,7 @@ else: ("layoutlmv2", ("LayoutLMv2ImageProcessor",)), ("layoutlmv3", ("LayoutLMv3ImageProcessor",)), ("levit", ("LevitImageProcessor",)), + ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")), ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")), ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")), ("llava_next_video", ("LlavaNextVideoImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 5df90ff5a8..d33d0f20d5 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -17,7 +17,6 @@ import warnings from collections import OrderedDict -from ...configuration_utils import PretrainedConfig from ...utils import logging from .auto_factory import ( _BaseAutoBackboneClass, @@ -161,6 +160,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("levit", "LevitModel"), ("lilt", "LiltModel"), ("llama", "LlamaModel"), + ("llama4", "Llama4ForConditionalGeneration"), ("longformer", "LongformerModel"), ("longt5", "LongT5Model"), ("luke", "LukeModel"), @@ -547,6 +547,8 @@ 
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("jamba", "JambaForCausalLM"), ("jetmoe", "JetMoeForCausalLM"), ("llama", "LlamaForCausalLM"), + ("llama4", "Llama4ForCausalLM"), + ("llama4_text", "Llama4ForCausalLM"), ("mamba", "MambaForCausalLM"), ("mamba2", "Mamba2ForCausalLM"), ("marian", "MarianForCausalLM"), @@ -634,6 +636,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict( ("ijepa", "IJepaModel"), ("imagegpt", "ImageGPTModel"), ("levit", "LevitModel"), + ("llama4", "Llama4VisionModel"), ("mllama", "MllamaVisionModel"), ("mobilenet_v1", "MobileNetV1Model"), ("mobilenet_v2", "MobileNetV2Model"), @@ -849,6 +852,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( ("idefics3", "Idefics3ForConditionalGeneration"), ("instructblip", "InstructBlipForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"), + ("llama4", "Llama4ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), @@ -1492,6 +1496,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict( ("emu3", "Emu3TextModel"), ("flaubert", "FlaubertModel"), ("ibert", "IBertModel"), + ("llama4", "Llama4TextModel"), ("longformer", "LongformerModel"), ("mllama", "MllamaTextModel"), ("mobilebert", "MobileBertModel"), @@ -1678,30 +1683,6 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag class AutoModelForCausalLM(_BaseAutoModelClass): _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING - @classmethod - def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig: - """ - Additional autoclass-specific config post-loading manipulation. In this specific autoclass, if the config has - a nested text decoder section, uses that section instead. - - Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own - config, rather than the config for the whole model. 
This is used e.g. to load the text-only part of a VLM. - """ - possible_text_config_names = ("decoder", "generator", "text_config") - text_config_names = [] - for text_config_name in possible_text_config_names: - if hasattr(config, text_config_name): - text_config_names += [text_config_name] - - text_config = config.get_text_config(decoder=True) - if text_config_names and type(text_config) in cls._model_mapping.keys(): - warnings.warn( - "Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. " - "`AutoModelForCausalLM` will be used to load only the text-to-text generation module.", - FutureWarning, - ) - return config - AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 48081b9df8..6e655edff1 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -77,6 +77,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ("kosmos-2", "Kosmos2Processor"), ("layoutlmv2", "LayoutLMv2Processor"), ("layoutlmv3", "LayoutLMv3Processor"), + ("llama4", "Llama4Processor"), ("llava", "LlavaProcessor"), ("llava_next", "LlavaNextProcessor"), ("llava_next_video", "LlavaNextVideoProcessor"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 42ff109cca..eb54d95ab6 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -292,6 +292,20 @@ else: "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "llama4", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), + ( + "llama4_text", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else 
None, + ), + ), ("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), diff --git a/src/transformers/models/llama4/__init__.py b/src/transformers/models/llama4/__init__.py new file mode 100644 index 0000000000..59fe1686ce --- /dev/null +++ b/src/transformers/models/llama4/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_llama4 import * + from .image_processing_llama4_fast import * + from .modeling_llama4 import * + from .processing_llama4 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/llama4/configuration_llama4.py b/src/transformers/models/llama4/configuration_llama4.py new file mode 100644 index 0000000000..1c4c00f48f --- /dev/null +++ b/src/transformers/models/llama4/configuration_llama4.py @@ -0,0 +1,432 @@ +# coding=utf-8 +# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. 
class Llama4VisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Llama4VisionModel`]. It is used to instantiate a
    Llama4 vision model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Llama4 109B.

    e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        num_hidden_layers (`int`, *optional*, defaults to 34):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input image.
        intermediate_size (`int`, *optional*, defaults to 5632):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        vision_output_dim (`int`, *optional*, defaults to 7680):
            Dimensionality of the vision model output. Includes output of transformer
            encoder with intermediate layers and global transformer encoder.
        image_size (`int`, *optional*, defaults to 448):
            The size (resolution) of each image *tile*.
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        vision_feature_layer (`int`, *optional*, defaults to -1): TODO
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): TODO
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        pixel_shuffle_ratio (`float`, *optional*, defaults to 0.5): TODO
        projector_input_dim (`int`, *optional*, defaults to 4096): TODO
        projector_output_dim (`int`, *optional*, defaults to 4096): TODO
        multi_modal_projector_bias (`bool`, *optional*, defaults to `False`): TODO
        projector_dropout (`float`, *optional*, defaults to 0.0): TODO
        attention_dropout (`float`, *optional*, defaults to 0.0): TODO
        rope_theta (`int`, *optional*, defaults to 10000): TODO
    """

    # Tensor-parallel sharding plan, keyed by module-name pattern.
    base_model_tp_plan = {
        "model.layers.*.self_attn.q_proj": "colwise",
        "model.layers.*.self_attn.k_proj": "colwise",
        "model.layers.*.self_attn.v_proj": "colwise",
        "model.layers.*.self_attn.o_proj": "rowwise",
        "vision_adapter.mlp.fc1": "colwise",
        "vision_adapter.mlp.fc2": "rowwise",
        "patch_embedding.linear": "colwise_rep",
    }
    model_type = "llama4_vision_model"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size: int = 768,
        hidden_act: str = "gelu",
        num_hidden_layers: int = 34,
        num_attention_heads: int = 16,
        num_channels: int = 3,
        intermediate_size: int = 5632,
        vision_output_dim: int = 7680,
        image_size: int = 448,
        patch_size: int = 14,
        norm_eps: float = 1e-5,
        vision_feature_layer=-1,
        vision_feature_select_strategy="default",
        initializer_range: float = 0.02,
        pixel_shuffle_ratio=0.5,
        projector_input_dim=4096,
        projector_output_dim=4096,
        multi_modal_projector_bias=False,
        projector_dropout=0.0,
        attention_dropout=0.0,
        rope_theta=10000,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.num_hidden_layers = num_hidden_layers
        self.num_channels = num_channels
        self.intermediate_size = intermediate_size
        self.image_size = image_size
        self.vision_output_dim = vision_output_dim
        self.patch_size = patch_size
        self.norm_eps = norm_eps
        self.num_attention_heads = num_attention_heads
        self.initializer_range = initializer_range
        self.pixel_shuffle_ratio = pixel_shuffle_ratio
        self.projector_input_dim = projector_input_dim
        self.projector_output_dim = projector_output_dim
        self.multi_modal_projector_bias = multi_modal_projector_bias
        self.projector_dropout = projector_dropout
        self.attention_dropout = attention_dropout
        self.vision_feature_layer = vision_feature_layer
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.rope_theta = rope_theta
        super().__init__(**kwargs)


class Llama4TextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Llama4TextModel`]. It is used to instantiate a
    Llama4 text model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Llama4 109B.

    e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 202048):
            Vocabulary size of the Llama4 text model. Defines the maximum number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`Llama4TextModel`].
        hidden_size (`int`, *optional*, defaults to 5120):
            Dimensionality of the embeddings and hidden states.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        intermediate_size_mlp (`int`, *optional*, defaults to 16384): TODO
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 40):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `None`, will default to `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 128):
            If `None`, falls back to `hidden_size // num_attention_heads`.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the encoder and pooler.
        max_position_embeddings (`int`, *optional*, defaults to 131072):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions.
        pad_token_id (`int`, *optional*):
            The id of the padding token. Defaults to `None`.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the beginning of sentence token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the end of sentence token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to `500000.0`):
            The base period of the RoPE embeddings.
        attention_dropout (`float`, *optional*, defaults to 0.0): TODO
        num_experts_per_tok (`int`, *optional*, defaults to 1): TODO
        num_local_experts (`int`, *optional*, defaults to 16): TODO
        moe_layers (`List[int]`, *optional*):
            If `None`, defaults to
            `list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))`.
        interleave_moe_layer_step (`int`, *optional*, defaults to 1):
            Step used to build the default `moe_layers` (see above).
        use_qk_norm (`bool`, *optional*, defaults to `True`): TODO
        output_router_logits (`bool`, *optional*, defaults to `False`): TODO
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001): TODO
        router_jitter_noise (`float`, *optional*, defaults to 0.0): TODO
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        no_rope_layers (`List[int]`, *optional*):
            One integer flag per layer. If `None` or empty, a default list is built with `0` for every
            `no_rope_layer_interval`-th layer and `1` for all other layers.
        no_rope_layer_interval (`int`, *optional*, defaults to 4):
            Interval used to build the default `no_rope_layers` (see above).
        attention_chunk_size (`int`, *optional*, defaults to 8192): TODO
        attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO
        floor_scale (`int`, *optional*, defaults to 8192): TODO
        attn_scale (`float`, *optional*, defaults to 0.1): TODO
    """

    model_type = "llama4_text"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Tensor-parallel sharding plan, keyed by module-name pattern.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.input_layernorm.weight": "sequence_parallel",
        "layers.*.post_attention_layernorm.weight": "sequence_parallel",
        "norm.weight": "sequence_parallel",
        "layers.*.feed_forward.shared_expert.gate_proj": "local_colwise",
        "layers.*.feed_forward.shared_expert.up_proj": "local_colwise",
        "layers.*.feed_forward.shared_expert.down_proj": "local_rowwise",
        "layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise",  # row because not linear
        "layers.*.feed_forward.experts.down_proj": "local_colwise",  # col because not linear
        "layers.*.feed_forward.experts": "local",
        "layers.*.feed_forward.gate_proj": "local_colwise",
        "layers.*.feed_forward.up_proj": "local_colwise",
        "layers.*.feed_forward.down_proj": "local_rowwise",
        "layers.*.feed_forward": "gather",
    }

    def __init__(
        self,
        vocab_size=202048,
        hidden_size=5120,
        intermediate_size=8192,
        intermediate_size_mlp=16384,
        num_hidden_layers=48,
        num_attention_heads=40,
        num_key_value_heads=8,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=500000,
        attention_dropout=0.0,
        num_experts_per_tok=1,
        num_local_experts=16,
        moe_layers=None,
        interleave_moe_layer_step=1,
        use_qk_norm=True,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        router_jitter_noise=0.0,
        rope_scaling=None,
        no_rope_layers=None,
        no_rope_layer_interval=4,
        attention_chunk_size=8192,
        attn_temperature_tuning=4,
        floor_scale=8192,
        attn_scale=0.1,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.attn_temperature_tuning = attn_temperature_tuning
        self.attn_scale = attn_scale
        self.floor_scale = floor_scale
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.intermediate_size_mlp = intermediate_size_mlp
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.rope_scaling = rope_scaling
        # Not configurable: always stored as False.
        self.attention_bias = False

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        self.use_qk_norm = use_qk_norm

        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts

        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_jitter_noise = router_jitter_noise

        # Default flags: 0 for every `no_rope_layer_interval`-th layer, 1 otherwise.
        default_no_rope_layers = [
            int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
        ]

        # no_rope_layers == [] is invalid as we cannot have 0 layers
        self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers

        self.interleave_moe_layer_step = interleave_moe_layer_step
        self.moe_layers = (
            moe_layers
            if moe_layers is not None
            else list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))
        )
        self.attention_chunk_size = attention_chunk_size


class Llama4Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Llama4ForConditionalGeneration`]. It is used to
    instantiate a Llama4 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Llama4 109B.

    e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vision_config (`Llama4VisionConfig` or `dict`, *optional*):
            The Llama4 Vision config. A `dict` is expanded into a [`Llama4VisionConfig`]; `None` uses the defaults.
        text_config (`Llama4TextConfig` or `dict`, *optional*):
            The Llama4 Text config. A `dict` is expanded into a [`Llama4TextConfig`]; `None` uses the defaults.
        boi_token_index (`int`, *optional*, defaults to 200080):
            The begin-of-image token index to wrap the image prompt.
        eoi_token_index (`int`, *optional*, defaults to 200081):
            The end-of-image token index to wrap the image prompt.
        image_token_index (`int`, *optional*, defaults to 200092):
            The image token index to encode the image prompt.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.

    Raises:
        TypeError: If `vision_config` or `text_config` is neither `None`, a `dict`, nor an instance of the
            matching config class.

    ```python
    >>> from transformers import Llama4ForConditionalGeneration, Llama4Config

    >>> # Initializing a Llama4 109B style configuration
    >>> configuration = Llama4Config()

    >>> # Initializing a model from the Llama4 109B style configuration
    >>> model = Llama4ForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "llama4"
    sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig}
    base_model_tp_plan = {
        "multi_modal_projector.linear_1": "colwise_rep",
    }

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        boi_token_index=200080,
        eoi_token_index=200081,
        image_token_index=200092,
        tie_word_embeddings=False,
        **kwargs,
    ):
        if vision_config is None:
            self.vision_config = Llama4VisionConfig()
            logger.info("vision_config is None, using default llama4 vision config")
        elif isinstance(vision_config, dict):
            self.vision_config = Llama4VisionConfig(**vision_config)
        elif isinstance(vision_config, Llama4VisionConfig):
            self.vision_config = vision_config
        else:
            # Fail fast: previously the attribute was silently left unset and the
            # error only surfaced later as an AttributeError far from the cause.
            raise TypeError(
                "`vision_config` must be None, a dict or a Llama4VisionConfig, "
                f"got {type(vision_config).__name__}"
            )

        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.image_token_index = image_token_index

        if text_config is None:
            self.text_config = Llama4TextConfig()
            logger.info("text_config is None, using default llama4 text config")
        elif isinstance(text_config, dict):
            self.text_config = Llama4TextConfig(**text_config)
        elif isinstance(text_config, Llama4TextConfig):
            self.text_config = text_config
        else:
            # Same fail-fast guard as for `vision_config`.
            raise TypeError(
                f"`text_config` must be None, a dict or a Llama4TextConfig, got {type(text_config).__name__}"
            )

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
r"layers.(\d+).attention.wo.weight": r"language_model.model.layers.\1.self_attn.o_proj.weight", + r"layers.(\d+).attention.wqkv.weight": r"language_model.model.layers.\1.self_attn.qkv_proj.weight", + + # MoE keys: no simple MLPmodel. + r"layers.(\d+).feed_forward.experts.moe_w_in_eD_F": r"language_model.model.layers.\1.feed_forward.experts.gate_proj" + weight_postfix, # will be fused with up + r"layers.(\d+).feed_forward.experts.moe_w_out_eF_D": r"language_model.model.layers.\1.feed_forward.experts.down_proj" + weight_postfix, # expert win + r"layers.(\d+).feed_forward.experts.moe_w_swiglu_eD_F": r"language_model.model.layers.\1.feed_forward.experts.up_proj" + weight_postfix, # fused with up + r"layers.(\d+).feed_forward.router_DE": r"language_model.model.layers.\1.feed_forward.router.weight", # used for top + r"layers.(\d+).feed_forward.w_in_shared_FD": r"language_model.model.layers.\1.feed_forward.shared_expert.gate_proj", # might need to be fused for efficiency? + r"layers.(\d+).feed_forward.w_out_shared_DF": r"language_model.model.layers.\1.feed_forward.shared_expert.down_proj", # might need to be fused for efficiency? + r"layers.(\d+).feed_forward.w_swiglu_FD": r"language_model.model.layers.\1.feed_forward.shared_expert.up_proj", # might need to be fused for efficiency? + r"layers.(\d+).feed_forward.global_gate_stats_3E": None, + # Unused keys in load hooks (explicitly removed) + r'layers.(\d+).attention.wqkv._extra_state': None, + r'layers.(\d+).attention.wo._extra_state': None, + # Key apparently unused in base models + r'layers.(\d+).feed_forward.expert_activation_DE': None, + + # MLP layer variant + r"layers.(\d+).feed_forward.w1.weight": r"language_model.model.layers.\1.feed_forward.gate_proj.weight", # might need to be fused for efficiency? + r"layers.(\d+).feed_forward.w3.weight": r"language_model.model.layers.\1.feed_forward.up_proj.weight", # might need to be fused for efficiency? 
+ # r"layers.(\d+).feed_forward.mlp.fc1_weight": r"language_model.model.layers.\1.feed_forward.gate_up_proj.weight", + r"layers.(\d+).feed_forward.mlp.fc2_weight": r"language_model.model.layers.\1.feed_forward.down_proj.weight", + r"layers.(\d+).feed_forward.mlp.layer_norm.weight": r"language_model.model.layers.\1.post_attention_layernorm.weight", + + # Vision encoder mapping + r"vision_embeddings.vision_encoder.conv1._linear": r"vision_model.patch_embedding.linear", + r'vision_embeddings.vision_adapter.mlp.c_fc': r"vision_model.vision_adapter.mlp.fc1", + r'vision_embeddings.vision_adapter.mlp.c_proj': r"vision_model.vision_adapter.mlp.fc2", + r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).attn.wq.(weight|bias)": r"vision_model.model.layers.\1.self_attn.q_proj.\2", + r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).attn.wk.(weight|bias)": r"vision_model.model.layers.\1.self_attn.k_proj.\2", + r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).attn.wv.(weight|bias)": r"vision_model.model.layers.\1.self_attn.v_proj.\2", + r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).attn.wo.(weight|bias)": r"vision_model.model.layers.\1.self_attn.o_proj.\2", + r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).mlp.c_fc": r"vision_model.model.layers.\1.mlp.fc1", + r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).mlp.c_proj": r"vision_model.model.layers.\1.mlp.fc2", + r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).ln_1.(weight|bias)": r"vision_model.model.layers.\1.input_layernorm.\2", + r"vision_embeddings.vision_encoder.transformer.resblocks.(\d+).ln_2.(weight|bias)": r"vision_model.model.layers.\1.post_attention_layernorm.\2", + # r'vision_embeddings.vision_encoder.ln_(1|2).(weight|bias)': r'vision_model.transformer.vision_encoder.layernorm_\1.\2', + r'vision_embeddings.vision_encoder.ln_post': r'vision_model.layernorm_post', + r'vision_embeddings.vision_encoder.ln_pre': 
def permute_for_rope(input_tensor, n_heads, dim1, dim2):
    """
    Reorder the rows of a query/key projection weight so it can be used with the sin/cos
    ROPE formulation instead of the complex one (done once here to avoid doing it on the fly).

    The `(dim1, dim2)` weight is viewed as `(n_heads, dim1 // n_heads // 2, 2, dim2)`, the two
    middle axes are swapped, and the result is flattened back to `(dim1, dim2)`.
    """
    half_head_dim = dim1 // n_heads // 2
    per_head = input_tensor.view(n_heads, half_head_dim, 2, dim2)
    return per_head.transpose(1, 2).reshape(dim1, dim2)
def get_concat_dim(key):
    """
    Pick the axis along which the per-shard pieces of `key` are concatenated.

    Returns 1 when `key` matches (as a regex search) any of the markers below, otherwise 0.
    """
    dim_one_markers = (
        # language model weights sharded along dim 1
        "feed_forward.router.weight",
        "self_attn.o_proj",
        "experts.gate_proj",
        "experts.up_proj",
        "expert.down_proj",
        # the dense "feed_forward.up_proj" / "feed_forward.gate_proj" entries are disabled,
        # so those keys fall through to dim 0
        "feed_forward.down_proj",
        "global_gate_stats",
        # vision weights sharded along dim 1
        "mlp.fc2.weight",  # covers all row-parallel layers across the vision tower
    )  # fmt: skip
    for marker in dim_one_markers:
        # Markers are regex patterns, so "." matches any single character.
        if re.search(marker, key) is not None:
            return 1
    return 0
def preprocess_keys(state_dict):
    """
    Unpack fused MLP projections (possibly to be removed once they are kept fused).

    Every entry whose key contains `mlp.fc1_weight` is split in two equal halves along dim 0;
    the halves are stored as `<prefix>w1.weight` and `<prefix>w3.weight`, where `<prefix>` is
    the part of the key before `mlp.fc1_weight`. All other entries are passed through untouched.
    Returns a new dict; `state_dict` itself is not modified.
    """
    fused_marker = "mlp.fc1_weight"
    unpacked = {}
    for name, tensor in state_dict.items():
        if fused_marker not in name:
            unpacked[name] = tensor
            continue
        stem = name.partition(fused_marker)[0]
        first_half, second_half = tensor.chunk(2, dim=0)
        unpacked[f"{stem}w1.weight"] = first_half
        unpacked[f"{stem}w3.weight"] = second_half
    return unpacked
compute additional params for weight conversion + num_heads_per_shard = num_heads // num_shards + dim_per_head = dim // num_heads + # intermediate_size = compute_intermediate_size(dim, multiple_of=params["multiple_of"]) + + num_key_value_heads = params["n_kv_heads"] # for GQA / MQA + + num_experts = params["moe_args"]["num_experts"] + interleave_moe_layer_step = params["moe_args"].get("interleave_moe_layer_step", 1) + + no_rope_layer_interval = params["nope_layer_interval"] + + bos_token_id = 200000 + eos_token_id = [200001, 200007, 200008] if instruct else 200001 + pad_token_id = 200018 + + text_config = Llama4TextConfig( + num_attention_heads=num_heads, + vocab_size=vocab_size, + hidden_size=dim, + rms_norm_eps=rms_norm_eps, + rope_theta=rope_theta, + num_hidden_layers=num_layers, + intermediate_size=8192, + intermediate_size_mlp=16384, + max_position_embeddings=max_context_length(input_base_path, instruct), + num_local_experts=num_experts, + interleave_moe_layer_step=interleave_moe_layer_step, + use_qk_norm=params["use_qk_norm"], + no_rope_layer_interval=no_rope_layer_interval, + attention_chunk_size=attention_chunk_size, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=False, # Constant set to False + torch_dtype=torch_dtype, + for_llm_compressor=_OFFLINE_QUANT_COMPATIBLE, + no_rope_layers=no_rope_layer_interval, + **config_kwargs, + ) + # default vision config frmo params + + vision_params = params["vision_args"] + vision_dim = vision_params["dim"] + vision_num_layers = vision_params["n_layers"] + image_size = vision_params["image_size"]["height"] # siglip config is outdated + vision_num_heads = vision_params["n_heads"] + + vision_output_dim = vision_params["output_dim"] + + vision_config = Llama4VisionConfig( + hidden_act="gelu", + num_hidden_layers=vision_num_layers, + image_size=image_size, + num_attention_heads=vision_num_heads, + hidden_size=vision_dim, + vision_output_dim=vision_output_dim, + ) 
+ + config = Llama4Config(text_config=text_config, vision_config=vision_config) + config.save_pretrained(model_path) + + print("Model config saved successfully...") + + # ------------------------------------------------------------ + # Convert weights + # ------------------------------------------------------------ + + if convert_checkpoints: + print(f"Fetching all parameters from the checkpoint at {input_base_path}...") + if num_shards == 1: + if os.path.exists(os.path.join(input_base_path, "consolidated.00.pth")): + path = os.path.join(input_base_path, "consolidated.00.pth") + else: + path = os.path.join(input_base_path, "consolidated.pth") + loaded = [safe_load(path)] + else: + loaded = [ + safe_load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth")) + for i in tqdm(range(num_shards), desc="Loading shards", unit="shard") + ] + loaded = [preprocess_keys(d) for d in loaded] + + all_keys_raw = list(loaded[0].keys()) + repeated_keys = [] + sharded_keys = [] + for _key in all_keys_raw: + try: + if (loaded[0][_key] == loaded[1][_key]).all(): + repeated_keys.append(_key) + else: + sharded_keys.append(_key) + except Exception as e: + print(f"Encountered exception {e} for {_key}") + print("Initializing an empty model") + with torch.device("meta"): + model = Llama4ForConditionalGeneration(config) + + print("Converting model...") + all_keys = list(loaded[0].keys()) + new_keys = convert_old_keys_to_new_keys(all_keys) + state_dict = {} + replicated_params = [] # To keep track of replicated weights. 
+ for key in tqdm(all_keys, desc="Renaming and processing all keys", unit="key"): + new_key = new_keys[key] + print(key, new_key) + if not is_param_same_across_shards(new_key): + current_parameter = [chunk.pop(key) for chunk in loaded if not isinstance(chunk[key], io.BytesIO)] + else: + print(f"{key} (now {new_key}) is the same across all shards.") + replicated_params.append((key, new_key)) + current_parameter = [loaded[0].pop(key)] if not isinstance(loaded[0][key], io.BytesIO) else [] + + if "running_gate_stats_3E" in key: + new_keys.pop(new_key) + continue + + concat_dim = get_concat_dim(new_key) + + # Post-process the current_parameter. + if "qkv_proj" in new_key: + queries = [] + keys = [] + values = [] + for param in current_parameter: + query, key_, value = param.split( + [ + num_heads * dim_per_head // num_shards, + num_key_value_heads * dim_per_head // num_shards, + num_key_value_heads * dim_per_head // num_shards, + ] + ) + queries.append(query.reshape(num_heads_per_shard, -1, dim)) + keys.append(key_.reshape(num_key_value_heads // num_shards, -1, dim)) + values.append(value.reshape(num_key_value_heads // num_shards, -1, dim)) + + queries = torch.cat(queries, dim=0).reshape(dim, dim) + keys = torch.cat(keys, dim=0).reshape(num_key_value_heads * dim_per_head, dim) + values = torch.cat(values, dim=0).reshape(num_key_value_heads * dim_per_head, dim) + # queries = permute_for_rope(queries, num_heads, dim, dim) + # keys = permute_for_rope(keys, num_key_value_heads, num_key_value_heads*dim_per_head, dim) + + q = new_key.replace("qkv", "q") + tqdm.write(f"Processing: {key.ljust(50)} ->\t {q}, {queries.shape}") + state_dict[q] = queries + + k = new_key.replace("qkv", "k") + tqdm.write(f"Processing: {key.ljust(50)} ->\t {k}, {keys.shape}") + state_dict[k] = keys + + v = new_key.replace("qkv", "v") + tqdm.write(f"Processing: {key.ljust(50)} ->\t {v}, {values.shape}") + state_dict[v] = values + elif _OFFLINE_QUANT_COMPATIBLE and "feed_forward.experts." 
in new_key: + # for experts, we need to split expert for offline quantiation purpose and don't need to fuse + expert_lists = [] + for k in current_parameter: + expert_lists.append( + list(k.reshape(num_experts, -1, k.shape[-1]).unbind(0)) + ) # [#expert * IN, OUT] -> #experts * [IN, OUT] + for i in range(num_experts): + expert = torch.cat([expert_list[i] for expert_list in expert_lists], dim=concat_dim) + expert_key = new_key.replace("experts.", f"experts.{i}.") + state_dict[expert_key] = expert.transpose(0, 1).contiguous() # [OUT, IN] + tqdm.write(f"Processing: {key.ljust(50)} ->\t {expert_key}, {state_dict[expert_key].shape}") + elif re.search(r"(gate|up)_proj", new_key): + path = new_key.split(".") + gate_key = re.sub(r"(gate|up)_proj", lambda m: "gate_proj", new_key) + up_key = re.sub(r"(gate|up)_proj", lambda m: "up_proj", new_key) + if gate_key == new_key: + state_dict[new_key] = torch.cat(current_parameter, dim=concat_dim) + elif new_key == up_key: + if "experts" not in new_key: + state_dict[new_key] = torch.cat(current_parameter, dim=concat_dim) + else: + gate_proj = state_dict.pop(gate_key) + gate_proj = [ + gate_proj.reshape(num_experts, -1, 8, 1024)[:, :, k, :].reshape(num_experts, -1, 1024) + for k in range(8) + ] + gate_proj = torch.cat(gate_proj, dim=-1) + + up_proj = [ + k.reshape(num_experts, -1, 8, 1024).reshape(num_experts, -1, 1024) + for k in current_parameter + ] + up_proj = torch.cat(up_proj, dim=-1) + + gate_up_proj = torch.cat((gate_proj, up_proj), dim=-1) + new_key = new_key.replace("up_proj", "gate_up_proj") + state_dict[new_key] = gate_up_proj.contiguous() + + tqdm.write(f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}") + elif "down_proj" in new_key: + current_parameter = torch.cat(current_parameter, dim=concat_dim) + if "experts" in new_key: + p = [] + for i in range(8): + p += [current_parameter.reshape(8, -1, 5120)[i, :, :].view(num_experts, -1, 5120)] + current_parameter = torch.cat(p, dim=1) + 
state_dict[new_key] = current_parameter.contiguous() + tqdm.write(f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}") + elif "router" in new_key: + current_parameter = torch.cat(current_parameter, dim=concat_dim) + state_dict[new_key] = current_parameter.transpose(0, 1) + elif "lm_head" in new_key: + current_parameter = torch.cat(current_parameter, dim=concat_dim).clone() + # TODO we need to do better than mean, works for now + # if (vocab_size - current_parameter.shape[0]) > 0: + # mean_embedding = torch.mean(current_parameter, dim=0)[:, None].repeat(vocab_size-current_parameter.shape[0],1) + # print(mean_embedding.shape) + # current_parameter = torch.cat((current_parameter, mean_embedding), dim=0) + state_dict[new_key] = current_parameter + tqdm.write( + f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}, concat dim = {concat_dim}" + ) + elif new_key == "vision_model.patch_embedding.linear.weight": + current_parameter = torch.cat(current_parameter, dim=concat_dim).clone() + # We don't reshape the patch embedding as we're using unfolded convolution as well + state_dict[new_key] = current_parameter # .reshape(-1, 3, vision_patch_size, vision_patch_size) + # generic concat for weights/select one for biases + elif isinstance(current_parameter, list) and len(current_parameter) > 0: + if not is_param_same_across_shards(new_key): + current_parameter = torch.cat(current_parameter, dim=concat_dim) + state_dict[new_key] = current_parameter + tqdm.write( + f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}, concat dim = {concat_dim}" + ) + elif is_param_same_across_shards(new_key): + state_dict[new_key] = current_parameter[0] + tqdm.write( + f"Processing: {key.ljust(50)} ->\t {new_key}, {state_dict[new_key].shape}, concat dim = {concat_dim}" + ) + + elif new_key == "": + # skip empty keys + continue + else: + # just load the parameter + state_dict[new_key] = current_parameter + tqdm.write( + f"Processing: 
def get_reserved_special_tokens(name, count, start_index=0):
    """
    Build `count` consecutive reserved special-token strings for the `name` family,
    numbered from `start_index` (inclusive) to `start_index + count` (exclusive).
    """
    prefix = f"<|{name}_reserved_special_token_"
    stop_index = start_index + count
    return [f"{prefix}{token_number}|>" for token_number in range(start_index, stop_index)]
def write_tokenizer(args):
    """
    Convert the original tiktoken tokenizer to a fast tokenizer and save it, wrapped in a
    `Llama4Processor` together with a `Llama4ImageProcessorFast`, to `args.output_dir`.

    Args:
        args: Parsed CLI arguments. The attributes read here are `input_dir` (folder holding
            `tokenizer.model`, also passed to `max_context_length`), `output_dir` (where the
            processor is saved) and `instruct` (selects the chat template, the EOS token and
            the maximum context length).
    """
    tokenizer_path = os.path.join(args.input_dir, "tokenizer.model")
    chat_template = "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %} \n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- else %}\n {#- FIXME: The processor requires an array, always. #}\n {%- set system_message = messages[0]['content'][0]['text']|trim %}\n {%- endif %}\n {%- set messages = messages[1:] %}\n {%- set user_supplied_system_message = true %}\n{%- else %}\n {%- set system_message = \"\" %}\n {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- System message if the user supplied one #}\n{%- if user_supplied_system_message %}\n {{- \"<|header_start|>system<|header_end|>\n\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\n\" }}\n {%- endif %}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|header_start|>user<|header_end|>\n\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\n\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' -}}\n {{- '<|python_start|>' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|python_end|>' }}\n {%- for tool_call in message.tool_calls %}\n {{- '{\"name\": \"' + tool_call.function.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.function.arguments | tojson }}\n {{- \"}\" }}\n {%- endfor %}\n {{- \"<|eot|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|header_start|>ipython<|header_end|>\n\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' }}\n{%- endif %}\n"  # noqa: E501

    special_tokens = BASIC_SPECIAL_TOKENS + LLAMA4_SPECIAL_TOKENS
    converter = Llama4Converter(
        vocab_file=tokenizer_path,
        pattern=O200K_PATTERN,
        special_tokens=special_tokens,
        # Only instruct checkpoints ship a chat template and stop on <|eot|>.
        chat_template=chat_template if args.instruct else None,
        bos_token="<|begin_of_text|>",
        eos_token="<|end_of_text|>" if not args.instruct else "<|eot|>",
        # Must be spelled exactly like the entry in LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS
        # ("<|finetune_right_pad|>", id 200018 == the pad_token_id written to the model config).
        # The Llama 3 spelling "<|finetune_right_pad_id|>" is not in this vocabulary.
        pad_token="<|finetune_right_pad|>",
        model_max_length=max_context_length(args.input_dir, args.instruct),
    )
    tokenizer = converter.converted_tokenizer

    image_processor = Llama4ImageProcessorFast()
    processor = Llama4Processor(
        image_processor=image_processor,
        tokenizer=tokenizer,
        chat_template=tokenizer.chat_template,
    )
    processor.save_pretrained(args.output_dir)
    del processor
Does not have to be the same as the number of consolidated_xx.pth", + ) + parser.add_argument( + "--instruct", + action="store_true", + help="Whether the model is an instruct model", + ) + parser.add_argument( + "--convert_checkpoints", + action="store_true", + help="Whether to convert the original weights (or skip if previously converted)", + ) + + args = parser.parse_args() + write_tokenizer(args) + + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + safe_serialization=args.safe_serialization, + num_shards=args.num_shards, + instruct=args.instruct, + convert_checkpoints=args.convert_checkpoints, + ) diff --git a/src/transformers/models/llama4/image_processing_llama4_fast.py b/src/transformers/models/llama4/image_processing_llama4_fast.py new file mode 100644 index 0000000000..6935ba798f --- /dev/null +++ b/src/transformers/models/llama4/image_processing_llama4_fast.py @@ -0,0 +1,480 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def get_factors(dividend: int) -> Set[int]:
    """
    Collect every factor of a given number, i.e. every divisor that leaves no remainder.
    For example, if dividend=12, the result is {1, 2, 3, 4, 6, 12}.

    Args:
        dividend (int): The number to find factors for.

    Returns:
        set: A set containing all factors of the number.
    """
    # Factors come in pairs (d, dividend // d) with d <= sqrt(dividend), so scanning up to the
    # square root finds both members of every pair.
    upper = int(dividend**0.5) + 1
    return {
        factor
        for candidate in range(1, upper)
        if dividend % candidate == 0
        for factor in (candidate, dividend // candidate)
    }
def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_width: int) -> "torch.Tensor":
    """
    Cut a batch of images into a `num_tiles_height x num_tiles_width` grid of equally sized tiles.

    Args:
        images (`torch.Tensor`):
            Batch of shape `(batch_size, num_channels, height, width)`. `height` and `width` are
            expected to be divisible by `num_tiles_height` and `num_tiles_width` respectively;
            otherwise the `view` below fails on the element count.
        num_tiles_height (`int`):
            Number of tiles along the height axis.
        num_tiles_width (`int`):
            Number of tiles along the width axis.

    Returns:
        `torch.Tensor`: Tiles of shape `(batch_size, num_tiles_height * num_tiles_width, num_channels,
        height // num_tiles_height, width // num_tiles_width)`, ordered row by row (all tiles of the
        first tile-row, then the second, ...).
    """
    # NOTE: annotations are quoted because `torch` is only imported when it is available.
    batch_size, num_channels, height, width = images.size()
    tile_height = height // num_tiles_height
    tile_width = width // num_tiles_width

    # Expose the tile grid as separate axes: (B, C, tiles_h, tile_h, tiles_w, tile_w).
    images = images.view(
        batch_size,
        num_channels,
        num_tiles_height,
        tile_height,
        num_tiles_width,
        tile_width,
    )
    # Bring the two grid axes next to the batch axis: (B, tiles_h, tiles_w, C, tile_h, tile_w).
    tiles = images.permute(0, 2, 4, 1, 3, 5).contiguous()
    # Merge the grid axes into a single tile axis.
    tiles = tiles.view(
        batch_size,
        num_tiles_width * num_tiles_height,
        num_channels,
        tile_height,
        tile_width,
    )
    return tiles
def pad_to_best_fit(
    images: "torch.Tensor",
    target_size: Tuple[int, int],
    background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
    """
    Pads images on the right and bottom so that they reach the target size.

    Args:
        images (`torch.Tensor`):
            The images to pad, either `(batch_size, num_channels, height, width)` or
            `(num_channels, height, width)`.
        target_size (`Tuple[int, int]`):
            The `(height, width)` the images should have after padding.
        background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
            The color to use for the padding. Can be an integer for single channel or a
            tuple of integers for multi-channel images. If passed as an integer in
            multi-channel mode, it is used for the first channel and `0` for the others.

    Returns:
        `torch.Tensor`: The padded images.

    Raises:
        ValueError: If `background_color` is a sequence whose length differs from the number of channels.
    """

    num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
    if isinstance(background_color, int):
        background_color = [background_color] + [0] * (num_channels - 1)
    elif len(background_color) != num_channels:
        # The check is an exact-length match, so the message says so.
        raise ValueError(
            f"background_color must have exactly {num_channels} elements to match the number of channels"
        )

    height, width = images.shape[-2:]
    target_height, target_width = target_size
    paste_x_right = target_width - width
    paste_y_right = target_height - height
    # padding order is [left, top, right, bottom]: the image stays anchored at the top-left corner
    padded_images = F.pad(images, padding=[0, 0, paste_x_right, paste_y_right], fill=background_color)

    return padded_images
+ possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each + row represents a possible resolution (height, width). + resize_to_max_canvas (bool): If True, will return the largest upscaling resolution. + + Returns: + List[int]: The best resolution [height, width] for the given image. + + Example: + >>> image_size = (200, 300) + >>> possible_resolutions = torch.tensor([[224, 672], + ... [672, 224], + ... [224, 448], + ... [448, 224], + ... [224, 224]]) + >>> get_best_fit(image_size, possible_resolutions) + [224, 448] + + We have: + scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467]) + scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200]) + scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467]) + Only one of the scales > 1: + upscaling_possible = tensor([1.1200, 1.1200]) + smallest_rescale = tensor(1.1200) + So we pick the resolution with the smallest smallest area: + areas = tensor([150528, 100352]) # [672, 224], [224, 448] + optimal_canvas = tensor([224, 448]) + """ + + original_height, original_width = image_size + + # get all possible resolutions heights/widths + target_heights, target_widths = ( + possible_resolutions[:, 0], + possible_resolutions[:, 1], + ) + + # get scaling factors to resize the image without distortion + scale_w = target_widths / original_width + scale_h = target_heights / original_height + + # get the min scale between width and height (limiting side -> no distortion) + scales = torch.where(scale_h > scale_w, scale_w, scale_h) + + # filter only scales that allow upscaling + upscaling_options = scales[scales >= 1] + if len(upscaling_options) > 0: + if resize_to_max_canvas: + selected_scale = torch.max(upscaling_options) + else: + selected_scale = torch.min(upscaling_options) + else: + # no upscaling possible, + # get the minimum downscaling (max scale for scales<1) + downscaling_options = scales[scales < 1] + selected_scale = torch.max(downscaling_options) + + # get all resolutions that support this scaling 
@add_start_docstrings(
    "Constructs a fast Llama4 image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
        max_patches (`int`, *optional*, defaults to 16):
            The maximum number of patches to be extracted from the image.
            Can be overridden by the `max_patches` parameter in the `preprocess` method.
        resize_to_max_canvas (`bool`, *optional*, defaults to False):
            Whether to resize the image to the maximum canvas size.
            If True, picks the canvas the allows the largest resizing without distortion.
            If False, downsample as little as possible, including no resizing at all,
            but never upsample, unless the image is smaller than the patch size.
    """,
)
class Llama4ImageProcessorFast(BaseImageProcessorFast):
    # Class-level defaults; each can be overridden through the constructor / `preprocess` kwargs.
    resample = PILImageResampling.BILINEAR
    image_mean = [0.5, 0.5, 0.5]
    image_std = [0.5, 0.5, 0.5]
    # Side length of one (square) tile.
    size = {"height": 336, "width": 336}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    max_patches = 16
    resize_to_max_canvas = False
    valid_kwargs = Llama4ImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[Llama4ImageProcessorKwargs]):
        super().__init__(**kwargs)

    def rescale_and_normalize(
        self,
        images: "torch.Tensor",
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Union[float, List[float]],
        image_std: Union[float, List[float]],
    ) -> "torch.Tensor":
        """
        Rescale and normalize images.

        Override to rescale and normalize the images in torch.bfloat16 as in the original implementation.
        The bfloat16 cast only happens when both `do_rescale` and `do_normalize` are set; when only one of
        them is set the input dtype is not explicitly changed here.

        Args:
            images (`torch.Tensor`): The images to process.
            do_rescale (`bool`): Whether to multiply the images by `rescale_factor`.
            rescale_factor (`float`): The factor used when rescaling.
            do_normalize (`bool`): Whether to normalize with `image_mean` / `image_std`.
            image_mean (`float` or `List[float]`): Mean passed to `self.normalize`.
            image_std (`float` or `List[float]`): Standard deviation passed to `self.normalize`.

        Returns:
            `torch.Tensor`: The rescaled and/or normalized images.
        """
        if do_rescale and do_normalize:
            images = images.to(dtype=torch.bfloat16) * rescale_factor
            images = self.normalize(images, image_mean, image_std)
        elif do_rescale:
            images = images * rescale_factor
        elif do_normalize:
            images = self.normalize(images, image_mean, image_std)

        return images

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
            max_patches (`int`, *optional*, defaults to 16):
                The maximum number of patches to be extracted from the image.
                Can be overridden by the `max_patches` parameter in the `preprocess` method.
            resize_to_max_canvas (`bool`, *optional*, defaults to False):
                Whether to resize the image to the maximum canvas size.
                If True, picks the canvas the allows the largest resizing without distortion.
                If False, downsample as little as possible, including no resizing at all,
                but never upsample, unless the image is smaller than the patch size.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Llama4ImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def _preprocess(
        self,
        images: List["torch.Tensor"],
        size: SizeDict,
        max_patches: int,
        resize_to_max_canvas: bool,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, List[float]]],
        image_std: Optional[Union[float, List[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Resize, pad, normalize and tile the images.

        For each group of same-shaped images: pick the best canvas among the supported resolutions, resize
        without distortion, pad to the canvas, rescale/normalize, split into tiles of side `size`, and append
        a "global" tile (the whole image resized to `size`) when there is more than one tile.

        Returns:
            `BatchFeature` with `pixel_values` (the tiles) and `aspect_ratios` (the `[ratio_h, ratio_w]`
            tile grid of every image).
        """
        possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size)
        possible_resolutions = torch.tensor(possible_resolutions)
        # process images by batch, grouped by shape
        grouped_images, grouped_images_index = group_images_by_shape(images)
        grouped_processed_images = {}
        grouped_aspect_ratios = {}
        for shape, stacked_images in grouped_images.items():
            image_size = stacked_images.shape[-2:]
            target_size = get_best_fit(image_size, possible_resolutions, resize_to_max_canvas=resize_to_max_canvas)

            # By default the image may fill the whole canvas. Always assigning it here avoids an unbound
            # (or stale, from a previous group) value when `resize_to_max_canvas` is True.
            target_size_without_distortion = target_size
            if not resize_to_max_canvas:
                # Limit upscaling: never grow beyond max(original side, tile side), and never beyond the canvas.
                max_upscaling_size = size.height
                new_target_height = min(max(image_size[0], max_upscaling_size), target_size[0])
                new_target_width = min(max(image_size[1], max_upscaling_size), target_size[1])
                target_size_without_distortion = (new_target_height, new_target_width)

            # resize to target_size while preserving aspect ratio
            new_size_without_distortion = get_max_res_without_distortion(image_size, target_size_without_distortion)
            # guard against a side collapsing to 0 for extreme aspect ratios
            new_size_without_distortion = SizeDict(
                height=max(new_size_without_distortion[0], 1), width=max(new_size_without_distortion[1], 1)
            )
            processed_images = self.resize(
                stacked_images,
                new_size_without_distortion,
                interpolation=interpolation,
            )

            # pad to target_size to be able to split into tiles
            processed_images = pad_to_best_fit(processed_images, target_size)
            processed_images = self.rescale_and_normalize(
                processed_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )

            # number of tiles along each axis of the canvas
            ratio_h, ratio_w = (
                target_size[0] // size.height,
                target_size[1] // size.width,
            )
            # split into tiles
            processed_images = split_to_tiles(processed_images, ratio_h, ratio_w)
            grouped_processed_images[shape] = processed_images
            grouped_aspect_ratios[shape] = torch.tensor([[ratio_h, ratio_w]] * stacked_images.shape[0])

            # add a global tile to the processed tile if there are more than one tile
            if ratio_h * ratio_w > 1:
                global_tiles = self.resize(
                    stacked_images,
                    size,
                    interpolation=interpolation,
                )
                global_tiles = self.rescale_and_normalize(
                    global_tiles, do_rescale, rescale_factor, do_normalize, image_mean, image_std
                )
                grouped_processed_images[shape] = torch.cat([processed_images, global_tiles.unsqueeze(1)], dim=1)
        processed_images = reorder_images(grouped_processed_images, grouped_images_index)
        aspect_ratios_list = reorder_images(grouped_aspect_ratios, grouped_images_index)

        processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
        aspect_ratios = torch.stack(aspect_ratios_list, dim=0) if return_tensors else aspect_ratios_list
        return BatchFeature(
            data={"pixel_values": processed_images, "aspect_ratios": aspect_ratios}, tensor_type=return_tensors
        )
class Llama4TextExperts(nn.Module):
    """
    Batched expert feed-forward: all experts are stored in two stacked weight tensors and applied with
    `torch.bmm`, one slice per expert.
    """

    def __init__(self, config: Llama4Config):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.intermediate_size = config.intermediate_size
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        # gate and up projections are fused along the last dimension
        fused_gate_up_shape = (self.num_experts, self.hidden_size, 2 * self.expert_dim)
        down_shape = (self.num_experts, self.expert_dim, self.hidden_size)
        self.gate_up_proj = nn.Parameter(torch.empty(*fused_gate_up_shape))
        self.down_proj = nn.Parameter(torch.empty(down_shape))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        This should really not be run on a single machine, as we are reaching compute bound:
        - the inputs are expected to be "sorted" per expert already.
        - the input is viewed as (num_experts, tokens_per_expert, hidden_size) so that each expert
          processes its own slice.

        Args:
            hidden_states (torch.Tensor): (num_experts * tokens_per_expert, hidden_size)
        Returns:
            torch.Tensor: (num_experts * tokens_per_expert, hidden_size)
        """
        per_expert_inputs = hidden_states.view(self.num_experts, -1, self.hidden_size)
        fused_projection = torch.bmm(per_expert_inputs, self.gate_up_proj)
        gate, up = fused_projection.chunk(2, dim=-1)  # not supported for DTensors
        gated = up * self.act_fn(gate)
        per_expert_outputs = torch.bmm(gated, self.down_proj)
        return per_expert_outputs.view(-1, self.hidden_size)
class Llama4TextRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """
        Root-mean-square normalization over the last dimension with a learnable per-feature scale
        (initialized to ones).
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        # divide by the RMS of the last dimension; `eps` keeps the rsqrt finite for all-zero rows
        mean_of_squares = x.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_of_squares + self.eps)
        return x * inverse_rms

    def forward(self, x):
        # normalize in float32, cast back to the input dtype, then apply the learned scale
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"
Let s maybe register a buffer for this! + router_indices = ( + torch.arange(tokens_per_expert, device=hidden_states.device).view(1, -1).expand(router_scores.size(0), -1) + ) + router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) + + router_indices = router_indices.reshape(-1, 1).expand(-1, hidden_dim) + routed_in = torch.gather( + input=hidden_states, + dim=0, + index=router_indices, + ).to(hidden_states.device) + # we gather inputs corresponding to each expert based on the router indices + routed_in = routed_in * router_scores.reshape(-1, 1) + routed_out = self.experts(routed_in) + out = self.shared_expert(hidden_states) + # now that we finished expert computation -> we scatter add because we gathered previously + # we have to do this because we used all experts on all tokens. This is faster than the for loop, tho you are compute bound + # this scales a lot better if you do EP! + out.scatter_add_(dim=0, index=router_indices, src=routed_out.view(-1, hidden_dim)) + return out, router_scores + + +class Llama4TextRotaryEmbedding(nn.Module): + def __init__(self, config: Llama4TextConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + self.rope_type = "llama3" if config.rope_scaling is not None else "default" + + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len 
= torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation + + # Advanced RoPE types (e.g. 
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Plain PyTorch ("eager") scaled dot-product attention with grouped key/value heads.

    Args:
        module (`nn.Module`):
            The calling attention module; `num_key_value_groups` and `training` are read from it.
        query (`torch.Tensor`):
            Query states of shape `(batch, num_attention_heads, q_len, head_dim)`.
        key (`torch.Tensor`):
            Key states of shape `(batch, num_key_value_heads, kv_len, head_dim)`.
        value (`torch.Tensor`):
            Value states of shape `(batch, num_key_value_heads, kv_len, head_dim)`.
        attention_mask (`torch.Tensor`, *optional*):
            Additive mask; it is sliced on its last dimension to the key length before being added
            to the attention scores.
        scaling (`float`):
            Factor the raw attention scores are multiplied with before the softmax.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to the attention weights (only active when `module.training`).

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: the attention output of shape
        `(batch, q_len, num_attention_heads, head_dim)` and the attention weights.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    # Use the caller-provided `scaling` rather than a hard-coded 1/sqrt(head_dim), so this function behaves
    # like the other attention backends, which all receive the same `scaling` argument.
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # softmax in float32 for numerical stability, then back to the query dtype
    attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if self.use_rope: # the 16E model skips rope for long context on certain layers + query_states, key_states = apply_rotary_emb( + query_states, key_states, position_embeddings.to(query_states.device) + ) + + if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm + query_states = self.qk_norm(query_states) + key_states = self.qk_norm(key_states) + + # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers + if self.attn_temperature_tuning and not self.use_rope: + attn_scales = ( + torch.log(torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0 + ) + attn_scales = attn_scales.view((*input_shape, 1, 1)) + query_states = (query_states * attn_scales).to(query_states.dtype) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
class Llama4TextDecoderLayer(nn.Module):
    """
    One decoder block: pre-norm self-attention followed by a pre-norm feed-forward (either a
    mixture-of-experts or a dense MLP), each wrapped in a residual connection.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Llama4TextAttention(config, layer_idx)
        # Truthy for every layer except each 4th one (layer_idx 3, 7, 11, ...).
        self.use_chunked_attention = int((layer_idx + 1) % 4 != 0)  # <=> use rope
        self.is_moe_layer = layer_idx in config.moe_layers
        if self.is_moe_layer:  # the 128E model interleaves dense / sparse
            self.feed_forward = Llama4TextMoe(config)
        else:
            self.feed_forward = Llama4TextMLP(config, intermediate_size=config.intermediate_size_mlp)

        self.input_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        chunk_causal_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run the block on `hidden_states`.

        Args:
            hidden_states: Input to the layer.
            attention_mask: Mask forwarded to self-attention unless replaced by `chunk_causal_mask`.
            chunk_causal_mask: When given and this layer uses chunked attention, it is used instead
                of `attention_mask`.
            position_ids: Accepted but not used inside this method.
            past_key_value: Cache object forwarded to self-attention.
            output_attentions: If truthy, the attention weights are appended to the returned tuple.
            output_router_logits: If truthy, the router output (`None` for dense layers) is appended
                to the returned tuple.
            use_cache: Forwarded to self-attention.
            cache_position: Forwarded to self-attention.
            position_embeddings: Forwarded to self-attention.

        Returns:
            A tuple whose first element is the output hidden states, optionally followed by the
            attention weights and then the router output, in that order.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # use local attention mask for ROPE layers
        if self.use_chunked_attention and chunk_causal_mask is not None:
            attention_mask = chunk_causal_mask

        # Self Attention
        attention_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + attention_states

        # Fully Connected
        residual = hidden_states

        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.feed_forward(hidden_states)
        # The MoE feed-forward returns a pair (output, router scores); the dense MLP returns a tensor.
        if self.is_moe_layer:
            hidden_states, router_logits = hidden_states
        else:
            router_logits = None
        # The feed-forward output is viewed back to the residual's shape before the residual add.
        hidden_states = residual + hidden_states.view(residual.shape)
        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if output_router_logits:
            outputs += (router_logits,)

        return outputs
@add_start_docstrings(
    "The bare Llama4 Model outputting raw hidden-states without any specific head on top.",
    LLAMA4_START_DOCSTRING,
)
class Llama4PreTrainedModel(PreTrainedModel):
    config_class = Llama4Config
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """
        Initialize the weights of `module`.

        Linear and embedding weights are drawn from a normal distribution with standard deviation
        `initializer_range` (read from the config itself or, if absent there, from its `text_config`).
        Linear biases and the embedding row at `padding_idx` are zeroed. The stacked expert weights of
        `Llama4TextExperts` are allocated with `torch.empty`, so they are drawn from the same
        distribution instead of being left as uninitialized memory.
        """
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Llama4TextExperts):
            # raw nn.Parameter tensors, not covered by the nn.Linear branch above
            module.gate_up_proj.data.normal_(mean=0.0, std=std)
            module.down_proj.data.normal_(mean=0.0, std=std)
+ + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. 
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
@add_start_docstrings(
    "The bare Llama4 Model outputting raw hidden-states without any specific head on top.",
    LLAMA4_START_DOCSTRING,
)
class Llama4TextModel(Llama4PreTrainedModel):
    # Text-only decoder stack: token embedding -> `config.num_hidden_layers` decoder layers -> final RMS norm.
    # Besides the regular causal mask it also builds a "chunked" causal mask that is handed to every layer.
    _no_split_modules = ["Llama4TextDecoderLayer"]
    base_model_prefix = "model"
    config_class = Llama4TextConfig

    def __init__(self, config: Llama4TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Llama4TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Llama4TextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Per-call flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # XOR: raises when both or neither of `input_ids` / `inputs_embeds` are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids.to(self.embed_tokens.weight.device))

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            # Continue numbering after whatever the cache already holds.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask, chunk_causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        freq_cis = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # NOTE(review): these arguments are positional, so their order must match the
                # parameter order of `Llama4TextDecoderLayer.forward` (not visible here) - confirm.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    chunk_causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    freq_cis,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    chunk_causal_mask=chunk_causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=freq_cis,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        output = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
        return output if return_dict else output.to_tuple()

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
        chunked_attention_mask=None,
    ):
        """
        Build the pair `(causal_mask, chunked_attention_mask)` handed to the decoder layers.

        Args:
            attention_mask (`torch.Tensor`): Mask supplied by the caller, or `None`.
            input_tensor (`torch.Tensor`): Input embeddings; provides sequence length, batch size, dtype and device.
            cache_position (`torch.Tensor`): Positions of the current tokens; the first and last entries are used.
            past_key_values (`Cache`): Cache object, or `None`.
            output_attentions (`bool`): When `True`, the SDPA "unmask fully masked rows" step is skipped.
            chunked_attention_mask: Initial value of the chunked mask; returned unchanged when no chunked mask is
                built.

        Returns:
            `tuple`: `(causal_mask, chunked_attention_mask)`. For `flash_attention_2` the incoming mask is returned
            twice when it contains a zero and `(None, None)` otherwise. Attention implementations other than
            `sdpa`, `flex_attention` and `eager` also get `(None, None)`.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask, attention_mask  # flash does not support chunked attn TODO support flash
            return None, None

        if self.config._attn_implementation not in ["sdpa", "flex_attention", "eager"]:
            return None, None

        sequence_length = input_tensor.shape[1]
        cache_position = cache_position.to(self.device)
        attention_chunk_size = self.config.attention_chunk_size

        first_cache_position = cache_position[0]
        last_cache_position = cache_position[-1]

        # to avoid graph break, we introduce this hack
        cond1 = first_cache_position >= attention_chunk_size
        cond2 = (first_cache_position < attention_chunk_size) & (
            first_cache_position + sequence_length > attention_chunk_size
        )

        key_length = torch.where(
            cond1,
            attention_chunk_size + sequence_length - 1,
            torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size),
        )

        if past_key_values is not None and past_key_values.is_compileable:
            # `get_max_cache_shape` is a method: it has to be called, otherwise `target_length` holds a bound
            # method and the `target_length > attention_chunk_size` comparison below cannot work.
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else sequence_length

        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                offsets = (first_cache_position, max(last_cache_position - key_length, 0))
                chunked_attention_mask = make_flex_block_causal_mask(
                    attention_mask, self.config.attention_chunk_size, sequence_length, key_length, offsets=offsets
                )
                attention_mask = make_flex_block_causal_mask(
                    attention_mask,
                    query_length=sequence_length,
                    key_length=past_key_values.get_max_cache_shape(),
                    offsets=None if sequence_length != 1 else (first_cache_position, 0),
                )
                return attention_mask, chunked_attention_mask
            if isinstance(attention_mask, BlockMask):
                return attention_mask, chunked_attention_mask

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        dtype, device = input_tensor.dtype, input_tensor.device
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        if target_length > self.config.attention_chunk_size:
            chunked_attention_mask = self.create_chunked_attention_mask(
                self.config.attention_chunk_size,
                start=first_cache_position,
                end=first_cache_position + key_length,
                device=device,
            )
            # Without a caller-supplied mask there is nothing to combine with; `tensor & None` would raise.
            if attention_mask is not None:
                chunked_attention_mask = chunked_attention_mask & attention_mask
            if sequence_length == 1:
                chunked_attention_mask = chunked_attention_mask[-1:]
            if self.config._attn_implementation == "eager":
                chunked_attention_mask = (
                    chunked_attention_mask[None, None, :, :]
                    .to(dtype)
                    .masked_fill(chunked_attention_mask, torch.finfo(dtype).min)
                )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu"]
            and attention_mask.ndim == 4
            and not output_attentions  # Only unmask for 4d masks
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
            # chunked_attention_mask = AttentionMaskConverter._unmask_unattended(chunked_attention_mask, min_dtype)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and chunked_attention_mask is not None:
            chunked_attention_mask = chunked_attention_mask.bool()
            causal_mask = causal_mask.bool()
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=first_cache_position,
                is_training=self.training,
            ):
                causal_mask = None
        return causal_mask, chunked_attention_mask

    def create_chunked_attention_mask(
        self, attention_chunk_size: int, start: int, end: int, device: torch.device
    ) -> torch.Tensor:
        """
        Generate the following:

        'What'      :  0 ■ ⬚ ⬚ ⬚ ⬚ ⬚    |
        '▁is'       :  1 ■ ■ ⬚ ⬚ ⬚ ⬚     |
        '▁ch'       :  2 ■ ■ ■ ⬚ ⬚ ⬚     |
        'unked'     :  3 ⬚ ⬚ ⬚ ■ ⬚ ⬚    |
        '▁attention':  4 ⬚ ⬚ ⬚ ■ ■ ⬚    |
        '?'         :  5 ⬚ ⬚ ⬚ ■ ■ ■     |

        If the chunk size is 3.
        This can just be applied over the already created attention mask

        Args:
            attention_chunk_size (`int`): Number of consecutive positions that form one chunk.
            start (`int`): First absolute position covered by the mask (inclusive).
            end (`int`): Last absolute position covered by the mask (exclusive).
            device (`torch.device`): Device the mask is created on.

        Returns:
            `torch.Tensor`: Boolean mask of shape `(end - start, end - start)`; entry `[i, j]` is `True` when
            position `j` is not after position `i` and both fall into the same chunk.
        """
        # Build the position vector once, directly on the target device.
        positions = torch.arange(start, end, device=device)
        chunk_ids = positions // attention_chunk_size
        same_chunk = chunk_ids.unsqueeze(0) == chunk_ids.unsqueeze(1)
        not_future = (positions.unsqueeze(0) - positions.unsqueeze(1)) <= 0
        return same_chunk & not_future

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Zero out (i.e. allow) every key position that is not after the query's cache position.
            causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # A sum of 0 means "causally allowed (0) but padded (0)": those entries get masked out.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device)
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Llama4ForCausalLM + + >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@dataclass +class Llama4CausalLMOutputWithPast(ModelOutput): + """ + Base class for Llava causal language model (or autoregressive) outputs. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size (batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
def pixel_shuffle(input_tensor, shuffle_ratio):
    """
    Rearrange a square grid of patch features, trading spatial resolution for channel depth.

    The `num_patches` axis is interpreted as a square `side x side` grid. Width and height are each scaled by
    `shuffle_ratio` while the channel dimension is scaled by `1 / shuffle_ratio**2`, so the total number of
    elements is preserved.

    Args:
        input_tensor (`torch.Tensor`): Tensor of shape `(batch_size, num_patches, channels)`.
        shuffle_ratio (`float`): Scaling factor applied to each spatial side of the grid.

    Returns:
        `torch.Tensor`: Tensor of shape `(batch_size, num_patches * shuffle_ratio**2, channels / shuffle_ratio**2)`.

    Raises:
        ValueError: If `num_patches` is not a perfect square.
    """
    # input_tensor: [batch_size, num_patches, channels]
    batch_size, num_patches, _ = input_tensor.shape

    # Exact integer square root; a float sqrt followed by truncation can silently accept a non-square
    # patch count and fold the leftover patches into the channel axis.
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError(f"pixel_shuffle expects a square number of patches, got {num_patches}")

    grid = input_tensor.view(batch_size, side, side, -1)
    _, height, width, channels = grid.size()

    # First pass: rescale the width axis, moving the difference into channels, then swap height/width.
    shuffled = grid.view(batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio))
    shuffled = shuffled.permute(0, 2, 1, 3).contiguous()

    # Second pass: same operation for the height axis, then swap back.
    shuffled = shuffled.view(
        batch_size, int(height * shuffle_ratio), int(width * shuffle_ratio), int(channels / (shuffle_ratio**2))
    )
    shuffled = shuffled.permute(0, 2, 1, 3).contiguous()

    return shuffled.view(batch_size, -1, shuffled.shape[-1])
# TODO there is a different RoPE for vision encoder, defined as below
def reshape_for_broadcast(freqs_ci: torch.Tensor, query: torch.Tensor):
    """
    View `freqs_ci` so that it broadcasts against `query`.

    The returned view keeps the sizes of `query`'s axis 1 and of its last axis and uses size 1 on every other
    axis.
    """
    kept_axes = (1, query.ndim - 1)
    broadcast_shape = [size if axis in kept_axes else 1 for axis, size in enumerate(query.shape)]
    return freqs_ci.view(*broadcast_shape)


def vision_apply_rotary_emb(
    query: torch.Tensor,
    key: torch.Tensor,
    freqs_ci: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate `query` and `key` by the complex frequencies `freqs_ci`.

    Consecutive pairs along the last axis are treated as (real, imaginary) parts of complex numbers, multiplied
    by `freqs_ci` in float32, flattened back from axis 3 onwards and cast to the dtype of the respective input.

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: The rotated query and the rotated key.
    """
    # Pair up the last axis: (..., d) -> (..., d / 2, 2), then reinterpret each pair as one complex value.
    query_pairs = query.float().reshape(*query.shape[:-1], -1, 2)
    key_pairs = key.float().reshape(*key.shape[:-1], -1, 2)
    query_complex = torch.view_as_complex(query_pairs)
    key_complex = torch.view_as_complex(key_pairs)

    rotation = reshape_for_broadcast(freqs_ci=freqs_ci, query=query_complex)  # freqs_ci[:,:,None,:]
    rotation = rotation.to(query_complex.device)

    rotated_query = torch.view_as_real(query_complex * rotation).flatten(3)
    rotated_key = torch.view_as_real(key_complex * rotation).flatten(3)
    # Cast back to the dtype of the inputs.
    return rotated_query.type_as(query), rotated_key.type_as(key)  # but this drops to 8e-3
(*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + query_states, key_states = vision_apply_rotary_emb(query_states, key_states, freqs_ci=freqs_ci) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + # flex disable because breaks on TP 8, embed is 88 not power of 2 + if self.config._attn_implementation not in ["eager", "flex_attention"]: + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
class Llama4VisionMLP(nn.Module):
    # Two-layer feed-forward block: hidden_size -> intermediate_size -> hidden_size,
    # with a GELU between the two biased linear layers.
    def __init__(self, config):
        super().__init__()
        hidden_dim = config.hidden_size
        inner_dim = config.intermediate_size

        self.config = config
        self.fc1 = nn.Linear(hidden_dim, inner_dim, bias=True)
        self.fc2 = nn.Linear(inner_dim, hidden_dim, bias=True)
        self.activation_fn = nn.GELU()  # ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the GELU activation and `fc2` to `hidden_states`, in that order."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class Llama4VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Llama4VisionEncoderLayer`].

    Args:
        config: Llama4VisionConfig
    """

    def __init__(self, config: Llama4VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([Llama4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_ci: torch.Tensor,  # TODO move this to an attribute instead of keeping it around
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            hidden_states (`torch.FloatTensor`):
                Input of the first encoder layer; each layer's output becomes the next layer's input.
                Presumably of shape `(batch_size, sequence_length, hidden_size)` - confirm against the caller.
            freqs_ci (`torch.Tensor`):
                Rotary frequencies, forwarded unchanged to every encoder layer.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Per-call flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Arguments are positional and must follow the layer's signature
                # `(hidden_state, freqs_ci, attention_mask, output_attentions)`; leaving out `freqs_ci`
                # would shift `attention_mask` into the `freqs_ci` slot.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    freqs_ci,
                    attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_state=hidden_states,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                    freqs_ci=freqs_ci,
                )

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            hidden_states = layer_outputs[0]

        # Also record the output of the last layer.
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
class Llama4VisionRotaryEmbedding(nn.Module):
    """
    Precomputes complex rotary factors for a square patch grid plus one trailing class-token slot.

    The grid has `config.image_size // config.patch_size` cells per side. The result is stored on
    `self.freqs_ci` as a plain attribute (not a registered buffer) with shape
    `(grid_size**2 + 1, 1, 2 * (freq_dim // 2))`, where
    `freq_dim = config.hidden_size // config.num_attention_heads // 2`.
    """

    def __init__(self, config):
        super().__init__()
        grid_size = config.image_size // config.patch_size
        num_cells = grid_size**2

        # Flat patch indices, followed by one extra slot marked with a negative id for the class token.
        positions = torch.arange(num_cells, dtype=torch.int32).reshape(num_cells, 1)
        positions = torch.cat([positions, positions[:1]], dim=0)
        positions[-1, -1] = -2  # ID_CLS_TOKEN

        column = positions % grid_size  # x coordinate inside the 2d grid
        row = positions // grid_size  # y coordinate inside the 2d grid

        freq_dim = config.hidden_size // config.num_attention_heads // 2
        exponents = torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim
        rope_freq = 1.0 / (config.rope_theta**exponents)

        def axis_angles(coordinates):
            # Coordinates are shifted to start at 1 before being scaled by each frequency.
            scaled = (coordinates + 1)[..., None] * rope_freq[None, None, :]
            return scaled.repeat_interleave(2, dim=-1)

        angles = torch.cat([axis_angles(column), axis_angles(row)], dim=-1)
        angles = angles.float().contiguous()[..., ::2]
        # The class-token slot (negative id) gets a zero angle, i.e. the identity rotation 1 + 0j.
        angles = angles.masked_fill(positions.reshape(-1, 1, 1) < 0, 0)

        cos_sin = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
        self.freqs_ci = torch.view_as_complex(cos_sin)

    def forward(self, hidden_states):
        """Return the precomputed factors moved to the device of `hidden_states`."""
        target_device = hidden_states.device
        return self.freqs_ci.to(target_device)
def forward(
    self,
    pixel_values: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
    r"""
    Embed image tiles into patches, encode them and project them with the vision adapter.

    Args:
        pixel_values (`torch.Tensor` of shape `(batch_size_times_num_tiles, num_channels, height, width)`):
            Image tiles to encode.
        attention_mask (`torch.Tensor`, *optional*):
            Accepted for interface compatibility; the encoder is always called with `attention_mask=None`.
        output_attentions (`bool`, *optional*):
            Whether to return the encoder attention weights. Defaults to `config.output_attentions`.
        output_hidden_states (`bool`, *optional*):
            Whether to return the encoder hidden states. Defaults to `config.output_hidden_states`.
        return_dict (`bool`, *optional*):
            Whether to return a [`BaseModelOutput`] instead of a plain tuple. Defaults to
            `config.use_return_dict`.

    Returns:
        A [`BaseModelOutput`] whose `last_hidden_state` is the vision-adapter output computed from the patch
        tokens (the class token is dropped beforehand), or a tuple of the non-`None` values among
        `(last_hidden_state, hidden_states, attentions)` when `return_dict` resolves to false.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # num_concurrent_media and num_chunks are both currently 1
    batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape
    num_concurrent_media = 1
    num_chunks = 1
    hidden_state = self.patch_embedding(pixel_values)
    _, num_patches, hidden_dim = hidden_state.shape

    # Add cls token (appended after the patch tokens, i.e. at the last sequence position)
    hidden_state = hidden_state.reshape(
        batch_size_times_num_tiles * num_concurrent_media * num_chunks, num_patches, hidden_dim
    )
    class_embedding = self.class_embedding.expand(hidden_state.shape[0], 1, hidden_state.shape[-1])
    hidden_state = torch.cat([hidden_state, class_embedding], dim=1)
    num_patches += 1

    # Position embeddings
    hidden_state = hidden_state.reshape(
        batch_size_times_num_tiles * num_concurrent_media, num_chunks, num_patches, hidden_dim
    )
    positional_embedding = self.positional_embedding_vlm.to(dtype=hidden_state.dtype, device=hidden_state.device)
    hidden_state = hidden_state + positional_embedding

    hidden_state = self.layernorm_pre(hidden_state)

    hidden_state = hidden_state.view(batch_size_times_num_tiles, -1, hidden_dim)
    freqs_ci = self.rotary_embedding(pixel_values)

    # Always request a ModelOutput from the encoder: the fields are read by name below, which would fail
    # on a plain tuple when the config disables return dicts.
    output = self.model(
        hidden_state,
        attention_mask=None,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        freqs_ci=freqs_ci,
        return_dict=True,
    )

    hidden_state = output.last_hidden_state

    hidden_state = self.layernorm_post(hidden_state)

    # Drop the class token that was appended at the end of the sequence.
    hidden_state = hidden_state[:, :-1, :]

    # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
    hidden_state = self.vision_adapter(hidden_state)

    hidden_states = output.hidden_states if output_hidden_states else None
    # Read attentions by name: integer indexing on a ModelOutput skips `None` fields, so a fixed index
    # is wrong whenever hidden states were not requested.
    attentions = output.attentions if output_attentions else None

    if not return_dict:
        return tuple(v for v in [hidden_state, hidden_states, attentions] if v is not None)

    return BaseModelOutput(
        last_hidden_state=hidden_state,
        hidden_states=hidden_states,
        attentions=attentions,
    )
def get_image_features(
    self,
    pixel_values: torch.FloatTensor,
    vision_feature_layer: Union[int, List[int]],
    vision_feature_select_strategy: str,
    **kwargs,
):
    """
    Obtains image last hidden states from the vision tower.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
            The tensors corresponding to the input images.
        vision_feature_layer (`Union[int, List[int]]`):
            The index of the layer to select the vision feature. Accepted for interface compatibility; this
            method always returns the vision model's last hidden state.
        vision_feature_select_strategy (`str`):
            The feature selection strategy. Must be one of `"default"` or `"full"`; it is only validated here.
        **kwargs:
            Extra keyword arguments forwarded to the vision model. Entries whose value is `None` are dropped.

    Returns:
        image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.

    Raises:
        ValueError: If `vision_feature_select_strategy` is not a supported value.
    """
    if vision_feature_select_strategy not in ["default", "full"]:
        # Report the argument itself; the model has no `vision_feature_select_strategy` attribute.
        raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    image_outputs = self.vision_model(pixel_values, output_hidden_states=False, **kwargs)
    hidden_state = image_outputs.last_hidden_state
    return hidden_state
@replace_return_docstrings(output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[Union[int, List[int]]] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    image_sizes: torch.Tensor = None,
    **lm_kwargs,
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
    r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).


    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

    >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:"
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
    ```"""
    # NOTE(review): the docstring example above still uses the LLaVA class and checkpoint rather than this
    # model's own; TODO replace it with a Llama4 example once a checkpoint id is confirmed.

    # Fall back to the config for every flag the caller left unset.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer
        if vision_feature_layer is not None
        else self.config.vision_config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_config.vision_feature_select_strategy
    )

    # Exactly one of `input_ids` / `inputs_embeds` must be given.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    # Image placeholders are located through `input_ids` below, so images cannot be combined with
    # pre-computed embeddings.
    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            image_sizes=image_sizes,
        )
        # Remembered so the flattened embeddings can be restored after the scatter.
        original_inputs_embeds_shape = inputs_embeds.shape

        # Flatten the vision features to 2D (one row per vision token) before projecting them.
        vision_flat = image_features.view(-1, image_features.size(-1))
        projected_vision_flat = self.multi_modal_projector(vision_flat)

        # True at every position whose token id equals `config.image_token_index`.
        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
        final_mask = special_image_mask.to(inputs_embeds.device)
        inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))

        final_mask_1d = final_mask[..., 0].reshape(-1)
        num_tokens_to_fill = final_mask_1d.sum()

        # Every image placeholder must receive exactly one projected vision embedding.
        if num_tokens_to_fill != projected_vision_flat.size(0):
            raise ValueError(
                f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
                f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
            )

        # In-place write of the projected vision rows into the placeholder rows of the flattened embeddings.
        expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
        inputs_embeds.masked_scatter_(expanded_mask, projected_vision_flat)

        inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)

    # The language model only ever sees embeddings; `input_ids` is not forwarded.
    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        **lm_kwargs,
    )

    logits = outputs[0]

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        if attention_mask is not None:
            # we use the input attention mask to shift the logits and labels, because it is 2D.
            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
            shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
            shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
            shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
        else:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
        )

    if not return_dict:
        # Tuple form: the loss (when computed) comes first, then the logits and the remaining LM outputs.
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Llama4CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Build a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D padding mask
    of shape `(batch_size, key_value_length)`. A mask that is already 4D is returned untouched.

    Args:
        attention_mask (`torch.Tensor`):
            Either a 2D mask of shape `(batch_size, key_value_length)`, a 4D mask of shape
            `(batch_size, 1, query_length, key_value_length)`, or `None`.
        sequence_length (`int`):
            Number of query positions being processed.
        target_length (`int`):
            Number of key/value positions the mask must cover. With a static cache this is the full cache
            length, including the part that is not filled yet.
        dtype (`torch.dtype`):
            Dtype of the generated mask; its most negative finite value marks blocked positions.
        device (`torch.device`):
            Device on which the generated mask is created.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size the mask is expanded to.

    Returns:
        `torch.Tensor`: the 4D mask, holding `0` where attention is allowed and `torch.finfo(dtype).min`
        where it is blocked (for a 4D input, the input itself).
    """
    # A 4D mask is assumed to be already inverted and sliced by the caller.
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    blocked_value = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), fill_value=blocked_value, dtype=dtype, device=device)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)
    # Keep the blocked value only for key positions lying strictly after each query's cache position.
    key_positions = torch.arange(target_length, device=device)
    mask *= key_positions > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is None:
        return mask

    mask = mask.clone()  # copy to contiguous memory for in-place edit
    mask_length = attention_mask.shape[-1]
    covered = mask[:, :, :, :mask_length]
    # Zero sum means: causally visible (0) and flagged as padding (0) in the 2D mask.
    is_padding = (covered + attention_mask[:, None, None, :].to(mask.device)) == 0
    mask[:, :, :, :mask_length] = covered.masked_fill(is_padding, blocked_value)
    return mask
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Optional, Union + +from transformers.processing_utils import ( + ImagesKwargs, + ProcessingKwargs, + ProcessorMixin, + Unpack, +) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +from ...image_processing_utils import BatchFeature +from ...image_utils import ( + ImageInput, + make_flat_list_of_images, +) + + +class Llama4ImagesKwargs(ImagesKwargs, total=False): + max_patches: Optional[int] + resize_to_max_canvas: Optional[bool] + + +class Llama4ProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Llama4ImagesKwargs + _defaults = { + "text_kwargs": { + "padding_side": "left", + }, + } + + +chat_template = "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %} \n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- else %}\n {#- FIXME: The processor requires an array, always. 
#}\n {%- set system_message = messages[0]['content'][0]['text']|trim %}\n {%- endif %}\n {%- set messages = messages[1:] %}\n {%- set user_supplied_system_message = true %}\n{%- else %}\n {%- set system_message = \"\" %}\n {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- System message if the user supplied one #}\n{%- if user_supplied_system_message %}\n {{- \"<|header_start|>system<|header_end|>\n\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\n\" }}\n {%- endif %}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|header_start|>user<|header_end|>\n\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\n\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' -}}\n {{- '<|python_start|>' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|python_end|>' }}\n {%- for tool_call in message.tool_calls %}\n {{- '{\"name\": \"' + tool_call.function.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.function.arguments | tojson }}\n {{- \"}\" }}\n {%- endfor %}\n {{- \"<|eot|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|header_start|>ipython<|header_end|>\n\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' }}\n{%- endif %}\n" + + +class Llama4Processor(ProcessorMixin): + r""" + Constructs a Llama4 processor which wraps a [`AutoImageProcessor`] and + 
[`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and + tokenizer functionalities. See the [`~Llama4Processor.__call__`] and [`~Llama4Processor.decode`] for more information. + Args: + image_processor ([`AutoImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*): + The tokenizer is a required input. + patch_size (`int`, *optional*, defaults to 28): + The size of image patches for tokenization. + img_size (`int`, *optional*, defaults to 364): + The size of the image to be tokenized. This should correspond to the size given to the image processor. + image_token (`str`, *optional*, defaults to `"<|image|>"`): + The token to be used to represent an image in the text. + downsample_factor (`int`, *optional*, defaults to 1): + The factor by which to scale the patch size. + start_of_img_token (`str`, *optional*, defaults to `"<|START_OF_IMG|>"`): + The token to be used to represent the start of an image in the text. + end_of_img_token (`str`, *optional*, defaults to `"<|END_OF_IMG|>"`): + The token to be used to represent the end of an image in the text. + img_patch_token (`str`, *optional*, defaults to `"<|IMG_PATCH|>"`): + The token to be used to represent an image patch in the text. + img_line_break_token (`str`, *optional*, defaults to `"<|IMG_LINE_BREAK|>"`): + The token to be used to represent a line break in the text. + tile_token (`str`, *optional*, defaults to `"TILE"`): + The token to be used to represent an image patch in the text. + tile_global_token (`str`, *optional*, defaults to `"TILE_GLOBAL"`): + The token to be used to represent the cover image in the text. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. 
def _prompt_split_image(self, aspect_ratio, num_patches_per_chunk):
    """
    Build the placeholder string that stands for one image in the prompt.

    Args:
        aspect_ratio: Pair `(ratio_h, ratio_w)` giving the number of tile rows and tile columns.
        num_patches_per_chunk: Number of `<|patch|>` tokens emitted for each tile and for the final
            whole-image chunk.

    Returns:
        String of the form `<|image_start|>` + tile grid (only when there is more than one tile) +
        `<|image|>` + one chunk of patch tokens + `<|image_end|>`. Inside the grid, tiles of a row are
        separated by `<|tile_x_separator|>` and every row is terminated by `<|tile_y_separator|>`.
    """
    rows, cols = aspect_ratio
    chunk = "<|patch|>" * num_patches_per_chunk

    pieces = ["<|image_start|>"]
    if rows * cols > 1:
        # Every row has the same content, so build it once and repeat it.
        one_row = "<|tile_x_separator|>".join([chunk] * cols) + "<|tile_y_separator|>"
        pieces.append(one_row * rows)
    pieces.extend(["<|image|>", chunk, "<|image_end|>"])

    return "".join(pieces)
def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Llama4ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text. + To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to + Llama4ImageProcessor's [`~Llama4ImageProcessor.__call__`] if `images` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + if text is None: + raise ValueError("You have to specify text.") + + output_kwargs = self._merge_kwargs( + Llama4ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if not isinstance(text, (list, tuple)): + text = [text] + + # Process images + image_inputs = {} + if images is not None: + images = make_flat_list_of_images(images) + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_height, image_width = image_inputs["pixel_values"][0].shape[-2:] + num_patches_per_chunk = int( + (image_height // self.patch_size) * (image_width // self.patch_size) // self.downsample_ratio + ) + aspect_ratios = image_inputs.pop("aspect_ratios") + + total_placeholders = sum(prompt.count(self.fake_image_token) for prompt in text) + if total_placeholders != len(images): + raise ValueError( + f"Found {total_placeholders} placeholders across the batch, " + f"but have {len(images)} flattened images." 
+ ) + + image_index = 0 + processed_text = [] + for prompt in text: + placeholder_count = prompt.count(self.fake_image_token) + if placeholder_count == 0: + # do nothing if there is no image + processed_text.append(prompt) + continue + prompt_splits = prompt.split(self.fake_image_token) + new_prompt = [] + for local_image_index, split_part in enumerate(prompt_splits): + new_prompt.append(split_part) + if local_image_index < placeholder_count: + tokens_for_this_image = self._prompt_split_image( + aspect_ratios[image_index], num_patches_per_chunk + ) + image_index += 1 + new_prompt.append(tokens_for_this_image) + processed_text.append("".join(new_prompt)) + + if image_index != len(images): + raise ValueError("Number of image placeholders in the prompt does not match the number of images.") + + text = processed_text + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(tokenizer_input_names) + list(image_processor_input_names) + + +__all__ = ["Llama4Processor"] diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 714d69baf4..1df663a26e 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -981,6 +981,8 @@ class Pipeline(_ScikitCompat, PushToHubMixin): else: self.device = device if device is not None else -1 + if torch.distributed.is_initialized(): + self.device = self.model.device logger.warning(f"Device set to use {self.device}") self.binary_output = binary_output diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index b486526388..b1c40e7ff2 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1178,10 +1178,6 @@ class ProcessorMixin(PushToHubMixin): unused_kwargs = {} unused_keys = set(kwargs_from_config) - set(valid_kwargs) if unused_keys: - unused_key_str = ", ".join(unused_keys) - logger.warning( - f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. 
" - ) unused_kwargs = {k: processor_config[k] for k in unused_keys} return unused_kwargs diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index cf8a120958..ff3856c9d4 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -43,8 +43,7 @@ is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_d _torch_distributed_available = torch.distributed.is_available() if is_torch_greater_or_equal("2.5") and _torch_distributed_available: - from torch.distributed.tensor import Replicate - from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel + pass def softmax_backward_data(parent, grad_output, output, dim, self): @@ -335,29 +334,6 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) return torch.isin(elements, test_elements) -# TODO need to add the __repr__ that shows that it is a colwise parallel -# See https://github.com/pytorch/pytorch/issues/145726 -def translate_to_torch_parallel_style(style: str): - """ - In model configurations, we use a neutral type (string) to specify parallel - styles, here we translate them into torch.distributed tensor-parallel - types. 
- """ - if not isinstance(style, str): - raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") - - if style == "colwise": - return ColwiseParallel() - elif style == "rowwise": - return RowwiseParallel() - elif style == "colwise_rep": - return ColwiseParallel(output_layouts=Replicate()) - elif style == "rowwise_rep": - return RowwiseParallel(input_layouts=Replicate()) - else: - raise ValueError(f"Unsupported parallel style value: {style}") - - def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs): """ LRU cache decorator from standard functools library, but with a workaround to disable @@ -382,88 +358,3 @@ def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs): return wrapper return decorator - - -def distribute_module( - module: nn.Module, - device_mesh=None, - partition_fn=None, - input_fn=None, - output_fn=None, -) -> nn.Module: - """ - This function expose three functions to control the parameters/inputs/outputs of the module: - - 1. To perform sharding on the module before runtime execution by specifying the - ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor` - parameters according to the `partition_fn` specified). - 2. To control the inputs or outputs of the module during runtime execution by - specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to - :class:`DTensor`, convert the output back to ``torch.Tensor``) - - Args: - module (:class:`nn.Module`): user module to be partitioned. - device_mesh (:class:`DeviceMesh`): the device mesh to place the module. - partition_fn (Callable): the function to partition parameters (i.e. shard certain - parameters across the ``device_mesh``). If ``partition_fn`` is not specified, - by default we replicate all module parameters of ``module`` across the mesh. - input_fn (Callable): specify the input distribution, i.e. could control how the - input of the module is sharded. 
``input_fn`` will be installed as a module - ``forward_pre_hook`` (pre forward hook). - output_fn (Callable): specify the output distribution, i.e. could control how the - output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be - installed as a module ``forward_hook`` (post forward hook). - - Returns: - A module that contains parameters/buffers that are all ``DTensor`` s. - - .. note:: - When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module`` - return nn.Module with PyTorch/XLA SPMD annotated parameters. See - `this issue `__ - for more details. The XLA integration is experimental and subject to change. - - """ - - torch._C._log_api_usage_once("torch.dtensor.distribute_module") - - device_mesh = device_mesh - - # register input_fn as module forward pre hook - if input_fn is not None: - # check the input_fn signature - num_args = len(inspect.signature(input_fn).parameters) - if num_args == 2: - # input_fn only takes in inputs and device mesh - logger.warning( - "Deprecating input_fn that takes two arguments (inputs, device_mesh), " - "please use input_fn that takes in (module, inputs, device_mesh) instead!", - FutureWarning, - stacklevel=2, - ) - module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] - elif num_args == 3: - # input_fn takes in module, inputs, device mesh - module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh)) - else: - raise ValueError(f"input_fn should take in 3 arguments, but got {num_args} arguments!") - # register output_fn as module forward hook - if output_fn is not None: - num_args = len(inspect.signature(output_fn).parameters) - if num_args == 2: - # output_fn only takes in outputs and device mesh - logger.warning( - "Deprecating output_fn that takes two arguments (inputs, device_mesh), " - "please use output_fn that takes in (module, inputs, device_mesh) instead!", - FutureWarning, - stacklevel=2, - ) - 
module.register_forward_hook( - lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] - ) - elif num_args == 3: - module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)) - else: - raise ValueError(f"output_fn should take in 3 arguments, but got {num_args} arguments!") - - return module diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py old mode 100755 new mode 100644 index 46c44b79f2..a780dca754 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -15,7 +15,8 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from ..utils import is_torch_available -from ..utils.quantization_config import QuantizationConfigMixin +from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod +from .quantizers_utils import get_module_from_name if TYPE_CHECKING: @@ -23,6 +24,9 @@ if TYPE_CHECKING: if is_torch_available(): import torch + from torch.nn import ModuleList +else: + ModuleList = str class HfQuantizer(ABC): @@ -198,6 +202,10 @@ class HfQuantizer(ABC): """ return + def update_tp_plan(self, config): + "updates the tp plan for the scales" + return config + def preprocess_model(self, model: "PreTrainedModel", **kwargs): """ Setting model attributes and/or converting model before weights loading. At this point @@ -212,6 +220,7 @@ class HfQuantizer(ABC): """ model.is_quantized = True model.quantization_method = self.quantization_config.quant_method + self._convert_model_for_quantization(model) return self._process_model_before_weight_loading(model, **kwargs) def postprocess_model(self, model: "PreTrainedModel", **kwargs): @@ -288,3 +297,44 @@ class HfQuantizer(ABC): @property @abstractmethod def is_trainable(self): ... 
+ + def _convert_model_for_quantization(self, model): + from accelerate import init_empty_weights + + for name, module in model.named_modules(): + module_class_name = module.__class__.__name__ + if ( + module_class_name in MODULES_TO_PATCH_FOR_QUANTIZATION.keys() + and self.quantization_config.quant_method == QuantizationMethod.COMPRESSED_TENSORS + ): + with init_empty_weights(): + parent_module, name = get_module_from_name(model, name) + parent_module._modules[name] = MODULES_TO_PATCH_FOR_QUANTIZATION[module_class_name]( + model.config.get_text_config() + ) + + +class SequentialLlama4TextExperts(ModuleList): + """ + A module that implements a compressed version of a list of expert modules. + This is specifically designed to work with Llama4TextExperts in MoE layers. + """ + + def __init__(self, config): + from transformers.models.llama4.modeling_llama4 import Llama4TextMLP + + super().__init__([Llama4TextMLP(config) for _ in range(config.num_local_experts)]) + self.num_experts = config.num_local_experts + + def forward( + self, + hidden_states: "torch.Tensor", + ) -> "torch.Tensor": + hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1]) + routed_out = torch.zeros_like(hidden_states) + for expert_idx in range(self.num_experts): + routed_out[expert_idx] = self[expert_idx](hidden_states[expert_idx]) + return routed_out + + +MODULES_TO_PATCH_FOR_QUANTIZATION = {"Llama4TextExperts": SequentialLlama4TextExperts} diff --git a/src/transformers/quantizers/quantizer_compressed_tensors.py b/src/transformers/quantizers/quantizer_compressed_tensors.py index ee1d0df380..4e45abf953 100644 --- a/src/transformers/quantizers/quantizer_compressed_tensors.py +++ b/src/transformers/quantizers/quantizer_compressed_tensors.py @@ -146,6 +146,19 @@ class CompressedTensorsHfQuantizer(HfQuantizer): self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN self.compressor.decompress(model_path=cache_path, model=model) + def 
update_tp_plan(self, config): + additional_plan = { + "layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise", + "layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise", + "layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise", + "layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise", + "layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise", + } + if config.get_text_config() is not None and config.get_text_config().base_model_tp_plan is not None: + config.get_text_config().base_model_tp_plan.update(additional_plan) + + return config + @property def is_trainable(self): return True diff --git a/src/transformers/quantizers/quantizer_fbgemm_fp8.py b/src/transformers/quantizers/quantizer_fbgemm_fp8.py index 69c55d01bd..f0fa8f063b 100644 --- a/src/transformers/quantizers/quantizer_fbgemm_fp8.py +++ b/src/transformers/quantizers/quantizer_fbgemm_fp8.py @@ -116,7 +116,7 @@ class FbgemmFp8HfQuantizer(HfQuantizer): state_dict: Dict[str, Any], **kwargs, ): - from ..integrations import FbgemmFp8Linear + from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts module, tensor_name = get_module_from_name(model, param_name) @@ -129,6 +129,13 @@ class FbgemmFp8HfQuantizer(HfQuantizer): if tensor_name == "weight_scale": raise ValueError("Expect unquantized weights but got a quantized weight_scale") return True + if isinstance(module, FbgemmFp8Llama4TextExperts): + if self.pre_quantized or tensor_name == "bias": + return False + else: + if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale": + raise ValueError("Expect unquantized weights but got a quantized weight_scale") + return True return False def create_quantized_param( @@ -143,12 +150,52 @@ class FbgemmFp8HfQuantizer(HfQuantizer): """ Quantizes weights into weight and weight_scale """ - new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value) + + from ..integrations import FbgemmFp8Llama4TextExperts 
module, tensor_name = get_module_from_name(model, param_name) - module._buffers[tensor_name] = new_value.to(target_device) - # to have the right output shape -> (out_features, 1) - module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device) + if isinstance(module, FbgemmFp8Llama4TextExperts): + if tensor_name == "gate_up_proj": + # Process each expert separately + # Transpose the second and third dimension + transposed_param = param_value.transpose(1, 2) + + # Reshape to 2D for quantization + original_shape = transposed_param.shape + flattened_param = transposed_param.reshape(-1, original_shape[-1]) + + # Quantize using per row instead of per column + new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param) + + # Reshape back to original dimensions + new_value = new_value_flat.reshape(original_shape) + new_value = new_value.transpose(1, 2) + weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1]) + elif tensor_name == "down_proj": + # Process each expert separately + # Transpose the weights for proper quantization + transposed_param = param_value.transpose(1, 2) + + # Reshape to 2D for quantization + original_shape = transposed_param.shape + flattened_param = transposed_param.reshape(-1, original_shape[-1]) + + # Quantize using per column + new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param) + + # Reshape back to original dimensions + new_value = new_value_flat.reshape(original_shape) + new_value = new_value.transpose(1, 2) + weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1) + + module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(weight_scale.to(target_device)) + else: + new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value) + module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter( + weight_scale.view(weight_scale.shape[0], 1).to(target_device) + ) + + 
module._parameters[tensor_name] = torch.nn.Parameter(new_value.to(target_device)) if unexpected_keys is not None and param_name in unexpected_keys: unexpected_keys.remove(param_name) @@ -165,25 +212,29 @@ class FbgemmFp8HfQuantizer(HfQuantizer): ): from ..integrations import replace_with_fbgemm_fp8_linear + tp_plan = model._tp_plan self.modules_to_not_convert = self.get_modules_to_not_convert( model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules ) + config = model.config model = replace_with_fbgemm_fp8_linear( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, pre_quantized=self.pre_quantized, + config=config, + tp_plan=tp_plan, ) model.config.quantization_config = self.quantization_config def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: - from ..integrations import FbgemmFp8Linear + from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts not_missing_keys = [] for name, module in model.named_modules(): - if isinstance(module, FbgemmFp8Linear): + if isinstance(module, FbgemmFp8Linear) or isinstance(module, FbgemmFp8Llama4TextExperts): for missing in missing_keys: if ( (name in missing or name in f"{prefix}.{missing}") diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index cb1169fe02..8e047569e7 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -3950,7 +3950,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): verbose (`bool`): Whether or not to print more information and warnings. 
""" - if max_length is None and len(ids) > self.model_max_length and verbose: + if max_length is None and len(ids) > self.model_max_length and verbose and self.model_max_length != 0: if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False): logger.warning( "Token indices sequence length is longer than the specified maximum sequence length " diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index be964d8490..b03455c89e 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -5823,6 +5823,41 @@ class LlamaPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) +class Llama4ForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Llama4ForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Llama4PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Llama4TextModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Llama4VisionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LlavaForConditionalGeneration(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_torchvision_objects.py b/src/transformers/utils/dummy_torchvision_objects.py index 50314fc55e..c59c9b4bdd 100644 --- a/src/transformers/utils/dummy_torchvision_objects.py +++ b/src/transformers/utils/dummy_torchvision_objects.py @@ -72,6 +72,13 @@ class GotOcr2ImageProcessorFast(metaclass=DummyObject): requires_backends(self, ["torchvision"]) +class 
Llama4ImageProcessorFast(metaclass=DummyObject): + _backends = ["torchvision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torchvision"]) + + class LlavaImageProcessorFast(metaclass=DummyObject): _backends = ["torchvision"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index ffd15d64ac..4e7293d50a 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -408,6 +408,13 @@ class LevitImageProcessor(metaclass=DummyObject): requires_backends(self, ["vision"]) +class Llama4ImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class LlavaImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/llama4/__init__.py b/tests/models/llama4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/llama4/test_image_processing_llama4.py b/tests/models/llama4/test_image_processing_llama4.py new file mode 100644 index 0000000000..bf84b3550d --- /dev/null +++ b/tests/models/llama4/test_image_processing_llama4.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import unittest + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + pass + +if is_vision_available() and is_torchvision_available(): + from transformers import Llama4ImageProcessorFast + + +class Llama4ImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + max_patches=1, + do_resize=True, + size=None, + do_normalize=True, + do_pad=False, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + do_convert_rgb=True, + ): + super().__init__() + size = size if size is not None else {"height": 20, "width": 20} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.max_patches = max_patches + self.do_resize = do_resize + self.size = size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_pad = do_pad + self.do_convert_rgb = do_convert_rgb + + def prepare_image_processor_dict(self): + return { + "max_patches": self.max_patches, + "do_resize": self.do_resize, + "size": self.size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + "do_pad": self.do_pad, + } + + def expected_output_image_shape(self, images): + return self.num_channels, self.size["height"], self.size["width"] + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + 
max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class Llama4ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + test_slow_image_processor = False + fast_image_processing_class = Llama4ImageProcessorFast if is_torchvision_available() else None + + def setUp(self): + super().setUp() + self.image_processor_tester = Llama4ImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processor, "do_resize")) + self.assertTrue(hasattr(image_processor, "size")) + self.assertTrue(hasattr(image_processor, "do_normalize")) + self.assertTrue(hasattr(image_processor, "image_mean")) + self.assertTrue(hasattr(image_processor, "image_std")) + self.assertTrue(hasattr(image_processor, "do_convert_rgb")) + + def test_split_tiles(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)[0] + processed_images = image_processor( + image, + max_patches=16, + ) + self.assertEqual(len(processed_images.pixel_values), 1) + self.assertEqual(processed_images.pixel_values[0].shape[0], 17) + self.assertEqual(processed_images.pixel_values[0].shape[-2:], (20, 20)) diff --git a/tests/models/llama4/test_modeling_llama4.py b/tests/models/llama4/test_modeling_llama4.py new file mode 100644 index 0000000000..65672993a0 --- /dev/null +++ b/tests/models/llama4/test_modeling_llama4.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
@slow
@require_torch_large_gpu
@require_read_token
class Llama4IntegrationTest(unittest.TestCase):
    """Integration tests that run generation with a Llama4 checkpoint.

    The model is loaded once per class in `setUpClass`; the processor and a
    default image + text conversation are rebuilt before every test in `setUp`.
    """

    # Single source of truth for the checkpoint used by the model and the processor.
    model_id = "ll-re/Llama-4-17B-Omni-Instruct"
    # Major CUDA compute capability of the runner (8 for A100 / A10, 7 for T4).
    # Logits / generations differ depending on the hardware; stays None when
    # CUDA is not available.
    cuda_compute_capability_major_version = None

    @classmethod
    def setUpClass(cls):
        """Record the CUDA compute capability and load the model once for all tests."""
        if is_torch_available() and torch.cuda.is_available():
            # 8 is for A100 / A10 and 7 for T4
            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
        cls.model = Llama4ForConditionalGeneration.from_pretrained(
            cls.model_id, device_map="auto", torch_dtype=torch.float32
        )

    def setUp(self):
        """Create a left-padding processor and the default single-image conversation."""
        self.processor = Llama4Processor.from_pretrained(self.model_id, padding_side="left")

        url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
        self.messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": url},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]

    def test_model_17b_16e_fp16(self):
        """Generate from a single text-only prompt and compare the decoded output."""
        # NOTE(review): the model is loaded in float32 in `setUpClass` although this
        # test is named "fp16" -- confirm which dtype is intended.
        # NOTE(review): these two expected strings do not look like answers to the
        # single "Who are you?" prompt below (and a single prompt decodes to a
        # single string) -- they appear to come from another test and must be
        # regenerated on the target hardware.
        EXPECTED_TEXT = [
            "The capital of France is Paris, which is located in the north-central part of the country. Paris is known for its iconic landmarks such as the",
            "Roses are red, violets are blue, and this poem is about you. Roses are red, violets are blue, and I love",
        ]

        messages = [
            {"role": "user", "content": "Who are you?"},
        ]
        # `tokenize=True` is required to get tensors back (the default only renders
        # the prompt string, which has no `.to()` and cannot be fed to `generate`).
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(torch_device)

        # Greedy decoding keeps the exact-match comparison below deterministic.
        output = self.model.generate(**inputs, max_new_tokens=100, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        self.assertEqual(output_text, EXPECTED_TEXT)

    def test_model_17b_16e_batch(self):
        """Generate for a padded batch of two image conversations and compare the outputs."""
        messages_2 = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
                    },
                    {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                    {"type": "text", "text": "Are these images identical?"},
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            [self.messages, messages_2],
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True,
        ).to(torch_device)

        output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
        output_text = self.processor.batch_decode(output, skip_special_tokens=True)

        # NOTE(review): the "user\n ... model\n" turn markers in these strings may
        # belong to a different model's chat template -- verify against real
        # Llama4 output before relying on this assertion.
        EXPECTED_TEXTS = [
            'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
            "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
        ]  # fmt: skip
        self.assertEqual(output_text, EXPECTED_TEXTS)
@require_vision
class Llama4ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    """Runs the shared processor test-suite against `Llama4Processor`.

    A processor is assembled from a small fast image processor and a pretrained
    fast tokenizer, then saved to a temporary directory from which the helper
    getters reload its components.
    """

    processor_class = Llama4Processor

    def setUp(self):
        """Build a `Llama4Processor` and save it into a fresh temporary directory."""
        self.tmpdirname = tempfile.mkdtemp()

        small_image_processor = Llama4ImageProcessorFast(max_patches=1, size={"height": 20, "width": 20})
        fast_tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit"
        )
        extra_kwargs = self.prepare_processor_dict()
        Llama4Processor(small_image_processor, fast_tokenizer, **extra_kwargs).save_pretrained(self.tmpdirname)

    def get_tokenizer(self, **kwargs):
        """Reload the saved processor and return its tokenizer."""
        reloaded = AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
        return reloaded.tokenizer

    def get_image_processor(self, **kwargs):
        """Reload the saved processor and return its image processor."""
        reloaded = AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
        return reloaded.image_processor

    def tearDown(self):
        """Delete the temporary directory created in `setUp`."""
        shutil.rmtree(self.tmpdirname)

    # Override as Llama4Processor needs image tokens in prompts
    # NOTE(review): the prompts below contain no visible image placeholder token;
    # it may have been lost from these literals -- confirm against the processor.
    def prepare_text_inputs(self, batch_size: Optional[int] = None):
        """Return a single prompt when `batch_size` is None, else a list of `batch_size` prompts.

        Raises:
            ValueError: if `batch_size` is smaller than 1.
        """
        if batch_size is None:
            return "lower newer "
        if batch_size < 1:
            raise ValueError("batch_size must be greater than 0")

        first_two = ["lower newer ", " upper older longer string"]
        if batch_size == 1:
            return first_two[:1]

        # Pad the batch with identical prompts beyond the first two.
        padding_prompts = [" lower newer" for _ in range(batch_size - 2)]
        return first_two + padding_prompts
["boi_token_index", "eoi_token_index"], + "Llama4TextConfig": [ + "interleave_moe_layer_step", + "no_rope_layer_interval", + "no_rope_layers", + "output_router_logits", + "router_aux_loss_coef", + "router_jitter_noise", + ], + "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"], } @@ -358,6 +368,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s "rope_theta", "partial_rotary_factor", "pretraining_tp", + "boi_token_index", + "eoi_token_index", ] attributes_used_in_generation = ["encoder_no_repeat_ngram_size"] diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index bdcec87c2b..b9f6645638 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -67,6 +67,7 @@ _re_parse_description = re.compile(r"\*optional\*, defaults to (.*)$") # docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the # line before the docstring. OBJECTS_TO_IGNORE = [ + "Llama4Processor", # Deprecated "InputExample", "InputFeatures", diff --git a/utils/check_dummies.py b/utils/check_dummies.py index 73d7ebbfd1..48a6b2fa71 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -223,13 +223,20 @@ def check_dummies(overwrite: bool = False): f.write(dummy_files[backend]) else: # Temporary fix to help people identify which objects introduced are not correctly protected. 
+ found = False for _actual, _dummy in zip( actual_dummies["torch"].split("class"), dummy_files["torch"].split("class") ): if _actual != _dummy: actual_broken = _actual dummy_broken = _dummy + found = True break + + if not found: + print("A transient error was found with the dummies, please investigate.") + continue + raise ValueError( "The main __init__ has objects that are not present in " f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py.\n" diff --git a/utils/check_repo.py b/utils/check_repo.py index b4119bb7b6..42beda83e6 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -144,6 +144,8 @@ IGNORE_NON_TESTED = ( "Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration. "MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests "MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests + "Llama4TextModel", # Building part of bigger (tested) model. # TODO: add tests + "Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests "Emu3VQVAE", # Building part of bigger (tested) model "Emu3TextModel", # Building part of bigger (tested) model ] @@ -170,6 +172,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [ "models/decision_transformer/test_modeling_decision_transformer.py", "models/bark/test_modeling_bark.py", "models/shieldgemma2/test_modeling_shieldgemma2.py", + "models/llama4/test_modeling_llama4.py", ] # Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and