Mllama fsdp (#36000)

* pixel input assignment revoked

* double send

* Update src/transformers/models/mllama/modeling_mllama.py

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
Benjamin Badger
2025-02-13 03:49:39 -05:00
committed by GitHub
parent 847854b023
commit 1614d196e8

View File

@@ -1541,7 +1541,9 @@ class MllamaVisionModel(MllamaPreTrainedModel):
aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1)
# Patch embedding
patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device))
target_dtype = self.patch_embedding.weight.dtype
target_device = self.patch_embedding.weight.device
patch_embeds = self.patch_embedding(pixel_values.to(target_device, target_dtype))
hidden_state = patch_embeds.flatten(2).transpose(1, 2)
# Tile embeddings