[PixtralLarge] Update Pixtral conversion script to support large format! (#34801)
* update conversion script * update for bias again * remove pdv * use my dir * Update how we initialize the tokenizer * Convert in bfloat16 * Undo that one again * fix config dump * .to() was broken for BatchMixFeature * quick debug breakpoint * put the breakpoint in the right place * Add a config flag for the multimodal projector bias * Add a config flag for the multimodal projector bias * Conversion script can load chat templates * Indent config for comparison * Stop clobbering the config * Re-enable the config clobber * Get rid of the config manual save - it has no effect! * Handle adapter bias correctly * Default vision transformer activation to silu * Remove legacy processing path * One commit with all the debug breakpoints before I delete them all, in case I need to revert * Update conversion * Remove vLLM debugging instrumentation * Drop xformers * Remove debug enumerates * make fixup * make fixup * Break copied from in pixtral * Propagate multimodal_projector_bias change * Propagate multimodal_projector_bias change * Remove debug device .to() * Restore attention weights output * Fix Pixtral test * Drop image_seq_length * Drop image_seq_length * Put the legacy processing code back * Add the bias option to the llava_next_video config * Add the bias option to the llava_next_video config * Make certain args required in converter * Make certain args required in converter * typo * make fixup * Reverting some dtype changes since it seems to work without them --------- Co-authored-by: arthur@huggingface.co <arthur@ip-26-0-166-244.ec2.internal> Co-authored-by: Matt <rocketknight1@gmail.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -50,6 +50,8 @@ class LlavaConfig(PretrainedConfig):
|
||||
The index of the layer to select the vision feature.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
Sequence length of one image embedding.
|
||||
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use bias in the multimodal projector.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -85,6 +87,7 @@ class LlavaConfig(PretrainedConfig):
|
||||
vision_feature_select_strategy="default",
|
||||
vision_feature_layer=-2,
|
||||
image_seq_length=576,
|
||||
multimodal_projector_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
@@ -127,6 +130,7 @@ class LlavaConfig(PretrainedConfig):
|
||||
text_config = CONFIG_MAPPING["llama"]()
|
||||
|
||||
self.text_config = text_config
|
||||
self.multimodal_projector_bias = multimodal_projector_bias
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@@ -86,10 +86,13 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
|
||||
class LlavaMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlavaConfig):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
|
||||
@@ -55,6 +55,8 @@ class LlavaNextConfig(PretrainedConfig):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
image_seq_length (`int`, *optional*, defaults to 576):
|
||||
Sequence length of one image embedding.
|
||||
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use bias in the multimodal projector.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -92,12 +94,14 @@ class LlavaNextConfig(PretrainedConfig):
|
||||
image_grid_pinpoints=None,
|
||||
tie_word_embeddings=False,
|
||||
image_seq_length=576,
|
||||
multimodal_projector_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.image_seq_length = image_seq_length
|
||||
self.multimodal_projector_bias = multimodal_projector_bias
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(
|
||||
|
||||
@@ -194,10 +194,13 @@ class LlavaNextCausalLMOutputWithPast(ModelOutput):
|
||||
class LlavaNextMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlavaNextConfig):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
|
||||
@@ -44,6 +44,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
The image token index to encode the image prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
The activation function used by the multimodal projector.
|
||||
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use bias in the multimodal projector.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
|
||||
@@ -95,6 +97,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
ignore_index=-100,
|
||||
image_token_index=32001,
|
||||
projector_hidden_act="gelu",
|
||||
multimodal_projector_bias=True,
|
||||
vision_feature_select_strategy="default",
|
||||
vision_feature_layer=-2,
|
||||
image_grid_pinpoints=None,
|
||||
@@ -114,6 +117,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.multimodal_projector_bias = multimodal_projector_bias
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(
|
||||
|
||||
@@ -179,10 +179,13 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
|
||||
class LlavaNextVideoMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlavaNextVideoConfig):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
|
||||
@@ -58,6 +58,8 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
The image token index to encode the image prompt.
|
||||
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||
The activation function used by the multimodal projector.
|
||||
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use bias in the multimodal projector.
|
||||
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
|
||||
The feature selection strategy used to select the vision feature from the vision backbone.
|
||||
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
|
||||
@@ -109,6 +111,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
ignore_index=-100,
|
||||
image_token_index=32001,
|
||||
projector_hidden_act="gelu",
|
||||
multimodal_projector_bias=True,
|
||||
vision_feature_select_strategy="default",
|
||||
vision_feature_layer=-2,
|
||||
image_grid_pinpoints=None,
|
||||
@@ -128,6 +131,7 @@ class LlavaNextVideoConfig(PretrainedConfig):
|
||||
self.ignore_index = ignore_index
|
||||
self.image_token_index = image_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.multimodal_projector_bias = multimodal_projector_bias
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(
|
||||
|
||||
@@ -58,6 +58,8 @@ class LlavaOnevisionConfig(PretrainedConfig):
|
||||
of the form `(height, width)`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use bias in the multimodal projector.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -95,11 +97,13 @@ class LlavaOnevisionConfig(PretrainedConfig):
|
||||
vision_aspect_ratio="anyres_max_9",
|
||||
image_grid_pinpoints=None,
|
||||
tie_word_embeddings=False,
|
||||
multimodal_projector_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.image_token_index = image_token_index
|
||||
self.video_token_index = video_token_index
|
||||
self.projector_hidden_act = projector_hidden_act
|
||||
self.multimodal_projector_bias = multimodal_projector_bias
|
||||
|
||||
if vision_feature_select_strategy not in ["default", "full"]:
|
||||
raise ValueError(
|
||||
|
||||
@@ -201,10 +201,13 @@ class LlavaOnevisionCausalLMOutputWithPast(ModelOutput):
|
||||
class LlavaOnevisionMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlavaOnevisionConfig):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@@ -27,7 +29,6 @@ from transformers import (
|
||||
PixtralImageProcessor,
|
||||
PixtralProcessor,
|
||||
PixtralVisionConfig,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.convert_slow_tokenizer import bytes_to_unicode
|
||||
|
||||
@@ -156,29 +157,18 @@ class MistralConverter:
|
||||
return tokenizer
|
||||
|
||||
|
||||
def convert_mistral_tokenizer():
|
||||
model_name = "mistralai/Pixtral-12B-2409"
|
||||
def convert_mistral_tokenizer(model_file):
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
tokenizer = MistralTokenizer.from_model(model_name)
|
||||
|
||||
vocab = tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
|
||||
all_special = [
|
||||
token.value if hasattr(token, "value") else token
|
||||
for token in tokenizer.instruct_tokenizer.tokenizer._all_special_tokens
|
||||
]
|
||||
specials_tokens = {token: all_special.index(token) for token in all_special}
|
||||
specials_tokens.update(vocab)
|
||||
vocab = specials_tokens
|
||||
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(),
|
||||
bos_token="<s>",
|
||||
unk_token="<unk>",
|
||||
eos_token="</s>",
|
||||
)
|
||||
tokenizer.model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
return tokenizer
|
||||
mistral_tokenizer = MistralTokenizer.from_file(model_file)
|
||||
vocab = mistral_tokenizer.instruct_tokenizer.tokenizer.vocab()
|
||||
control_token_ids = mistral_tokenizer.instruct_tokenizer.tokenizer._control_tokens
|
||||
all_special = [vocab[id] for id in control_token_ids]
|
||||
hf_tokenizer = LlamaTokenizer(model_file)
|
||||
# Do I need to exclude tokens that are already special?
|
||||
hf_tokenizer.add_special_tokens({"additional_special_tokens": all_special})
|
||||
hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
|
||||
return hf_tokenizer
|
||||
|
||||
|
||||
def permute_for_rope(value, n_heads, config):
|
||||
@@ -187,7 +177,7 @@ def permute_for_rope(value, n_heads, config):
|
||||
return value.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
|
||||
|
||||
|
||||
def convert_dictionnary(original_state_dict, vision_config, text_config):
|
||||
def convert_dictionary(original_state_dict, vision_config, text_config):
|
||||
new_dict = {}
|
||||
|
||||
all_keys = "\n" + "\n".join(original_state_dict.keys())
|
||||
@@ -208,7 +198,6 @@ def convert_dictionnary(original_state_dict, vision_config, text_config):
|
||||
num_attention_heads = _config.num_attention_heads
|
||||
if "k_proj" in new_key:
|
||||
num_attention_heads = _config.num_key_value_heads
|
||||
# convert the text model (basically mistral model)
|
||||
|
||||
if "q_proj" in new_key or "k_proj" in new_key:
|
||||
value = permute_for_rope(value, num_attention_heads, _config)
|
||||
@@ -217,29 +206,57 @@ def convert_dictionnary(original_state_dict, vision_config, text_config):
|
||||
return new_dict
|
||||
|
||||
|
||||
def convert_mistral_model(input_dir, output_dir):
|
||||
text_config = MistralConfig(
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
hidden_size=5120,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=14336,
|
||||
max_position_embeddings=1024000,
|
||||
model_type="mistral",
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=40,
|
||||
num_key_value_heads=8,
|
||||
rms_norm_eps=1e-05,
|
||||
rope_theta=1000000000.0,
|
||||
sliding_window=None,
|
||||
tie_word_embeddings=False,
|
||||
vocab_size=131072,
|
||||
)
|
||||
MISTRAL_CONFIG_MAPPING = {
|
||||
"dim": "hidden_size",
|
||||
"hidden_dim": "intermediate_size",
|
||||
"n_kv_heads": "num_key_value_heads",
|
||||
"n_heads": "num_attention_heads",
|
||||
"n_layers": "num_hidden_layers",
|
||||
}
|
||||
|
||||
vision_config = PixtralVisionConfig()
|
||||
|
||||
def convert_mistral_model(input_dir, output_dir):
|
||||
vision_config = {}
|
||||
if os.path.isfile(f"{input_dir}/params.json"):
|
||||
with open(f"{input_dir}/params.json") as f:
|
||||
param_json = json.load(f)
|
||||
vision_config = param_json.pop("vision_encoder")
|
||||
for k, v in MISTRAL_CONFIG_MAPPING.items():
|
||||
value = param_json.pop(k)
|
||||
param_json[v] = value
|
||||
if "hidden_act" not in vision_config:
|
||||
vision_config["hidden_act"] = "silu"
|
||||
text_config = MistralConfig(
|
||||
**param_json,
|
||||
hidden_act="silu",
|
||||
sliding_window=None,
|
||||
tie_word_embeddings=False,
|
||||
is_composition=True,
|
||||
rms_norm_eps=1e-5,
|
||||
)
|
||||
else:
|
||||
text_config = MistralConfig(
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
hidden_size=5120,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=14336,
|
||||
max_position_embeddings=1024000,
|
||||
model_type="mistral",
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=40,
|
||||
num_key_value_heads=8,
|
||||
rms_norm_eps=1e-05,
|
||||
rope_theta=1000000000.0,
|
||||
sliding_window=None,
|
||||
tie_word_embeddings=False,
|
||||
vocab_size=131072,
|
||||
)
|
||||
adapter_bias = vision_config.pop("adapter_bias", True)
|
||||
vision_config = PixtralVisionConfig(**vision_config)
|
||||
config = LlavaConfig(
|
||||
vision_config,
|
||||
text_config,
|
||||
@@ -247,38 +264,55 @@ def convert_mistral_model(input_dir, output_dir):
|
||||
image_token_index=10,
|
||||
vision_feature_select_strategy="full",
|
||||
image_seq_length=1,
|
||||
multimodal_projector_bias=adapter_bias,
|
||||
)
|
||||
config.architectures = ["LlavaForConditionalGeneration"]
|
||||
config.save_pretrained(output_dir)
|
||||
full_original_state_dict = {}
|
||||
safetensors_files = sorted([file for file in os.listdir(input_dir) if file.endswith(".safetensors")])
|
||||
if len(safetensors_files) == 1:
|
||||
full_original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
|
||||
else:
|
||||
for file in safetensors_files:
|
||||
loaded_dict = safe_load_file(f"{input_dir}/{file}")
|
||||
full_original_state_dict.update(loaded_dict)
|
||||
|
||||
original_state_dict = safe_load_file(f"{input_dir}/consolidated.safetensors")
|
||||
new_dict = convert_dictionnary(original_state_dict, vision_config, text_config)
|
||||
|
||||
new_dict = convert_dictionary(full_original_state_dict, vision_config, text_config)
|
||||
with torch.device("meta"):
|
||||
model = LlavaForConditionalGeneration(config)
|
||||
model.load_state_dict(new_dict, strict=True, assign=True)
|
||||
|
||||
model.save_pretrained(output_dir)
|
||||
|
||||
tokenizer = convert_mistral_tokenizer()
|
||||
image_processor = PixtralImageProcessor()
|
||||
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
|
||||
processor.save_pretrained(output_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_file", help="Location of the specific tokenizer model file to use.", required=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat_template_file",
|
||||
help="Optional file containing a raw chat template. Will be set as the processor's chat template.",
|
||||
required=False,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_mistral_model(args.input_dir, args.output_dir)
|
||||
tokenizer = convert_mistral_tokenizer(args.tokenizer_file)
|
||||
image_processor = PixtralImageProcessor()
|
||||
processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor, image_token="[IMG]")
|
||||
if args.chat_template_file:
|
||||
processor.chat_template = open(args.chat_template_file).read()
|
||||
processor.save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -37,7 +37,7 @@ from ...image_utils import (
|
||||
validate_kwargs,
|
||||
validate_preprocess_arguments,
|
||||
)
|
||||
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, is_vision_available, logging
|
||||
from ...utils import TensorType, is_torch_device, is_torch_dtype, is_vision_available, logging
|
||||
from ...utils.import_utils import requires_backends
|
||||
|
||||
|
||||
@@ -63,10 +63,24 @@ class BatchMixFeature(BatchFeature):
|
||||
Returns:
|
||||
[`BatchFeature`]: The same instance after modification.
|
||||
"""
|
||||
|
||||
def _recursive_to(obj, device, *args, **kwargs):
|
||||
# Lists can be nested, so keep digging until we hit tensors
|
||||
if isinstance(obj, list):
|
||||
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
|
||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
|
||||
# cast and send to device
|
||||
return obj.to(*args, **kwargs)
|
||||
elif isinstance(obj, torch.Tensor) and device is not None:
|
||||
# only send to device, don't cast
|
||||
return obj.to(device=device)
|
||||
else:
|
||||
return obj
|
||||
|
||||
requires_backends(self, ["torch"])
|
||||
import torch # noqa
|
||||
|
||||
new_data = {}
|
||||
device = kwargs.get("device")
|
||||
# Check if the args are a device or a dtype
|
||||
if device is None and len(args) > 0:
|
||||
@@ -80,21 +94,8 @@ class BatchMixFeature(BatchFeature):
|
||||
else:
|
||||
# it's something else
|
||||
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||
for k, v in self.items():
|
||||
# check if v is a floating point
|
||||
if isinstance(v, list):
|
||||
new_data[k] = [
|
||||
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
|
||||
]
|
||||
elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
|
||||
# cast and send to device
|
||||
new_data[k] = v.to(*args, **kwargs)
|
||||
elif isinstance(v, torch.Tensor) and device is not None:
|
||||
new_data[k] = v.to(device=device)
|
||||
else:
|
||||
new_data[k] = v
|
||||
self.data = new_data
|
||||
|
||||
self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
|
||||
return self
|
||||
|
||||
|
||||
|
||||
@@ -126,7 +126,6 @@ def rotate_half(x):
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
@@ -491,18 +490,20 @@ class PixtralVisionModel(PixtralPreTrainedModel):
|
||||
all tokens of all images of shape (N_toks, D)
|
||||
"""
|
||||
# pass images through initial convolution independently
|
||||
patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values]
|
||||
if len(pixel_values) > 1:
|
||||
raise ValueError("Batching/padding not supported yet!")
|
||||
patch_embeds_list = [self.patch_conv(img.to(self.dtype)) for sample in pixel_values for img in sample]
|
||||
|
||||
# flatten to a single sequence
|
||||
patch_embeds = torch.cat([p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
|
||||
patch_embeds = torch.cat([p.flatten(1).T for p in patch_embeds_list], dim=0).unsqueeze(0)
|
||||
patch_embeds = self.ln_pre(patch_embeds)
|
||||
|
||||
# positional embeddings
|
||||
position_ids = position_ids_in_meshgrid(
|
||||
patch_embeds_list, max_width=self.config.image_size // self.config.patch_size
|
||||
).to(self.device)
|
||||
|
||||
position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)
|
||||
|
||||
attention_mask = generate_block_attention_mask(
|
||||
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...feature_extraction_utils import BatchFeature
|
||||
from ...image_utils import ImageInput, is_valid_image, load_image
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
|
||||
from ...utils import is_torch_device, is_torch_dtype, logging, requires_backends
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -66,10 +66,24 @@ class BatchMixFeature(BatchFeature):
|
||||
Returns:
|
||||
[`BatchFeature`]: The same instance after modification.
|
||||
"""
|
||||
|
||||
def _recursive_to(obj, device, *args, **kwargs):
|
||||
# Lists can be nested, so keep digging until we hit tensors
|
||||
if isinstance(obj, list):
|
||||
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
|
||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
|
||||
# cast and send to device
|
||||
return obj.to(*args, **kwargs)
|
||||
elif isinstance(obj, torch.Tensor) and device is not None:
|
||||
# only send to device, don't cast
|
||||
return obj.to(device=device)
|
||||
else:
|
||||
return obj
|
||||
|
||||
requires_backends(self, ["torch"])
|
||||
import torch # noqa
|
||||
|
||||
new_data = {}
|
||||
device = kwargs.get("device")
|
||||
# Check if the args are a device or a dtype
|
||||
if device is None and len(args) > 0:
|
||||
@@ -83,21 +97,8 @@ class BatchMixFeature(BatchFeature):
|
||||
else:
|
||||
# it's something else
|
||||
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||
for k, v in self.items():
|
||||
# check if v is a floating point
|
||||
if isinstance(v, list):
|
||||
new_data[k] = [
|
||||
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
|
||||
]
|
||||
elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
|
||||
# cast and send to device
|
||||
new_data[k] = v.to(*args, **kwargs)
|
||||
elif isinstance(v, torch.Tensor) and device is not None:
|
||||
new_data[k] = v.to(device=device)
|
||||
else:
|
||||
new_data[k] = v
|
||||
self.data = new_data
|
||||
|
||||
self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
|
||||
return self
|
||||
|
||||
|
||||
@@ -204,12 +205,21 @@ class PixtralProcessor(ProcessorMixin):
|
||||
|
||||
if images is not None:
|
||||
if is_image_or_image_url(images):
|
||||
images = [[images]]
|
||||
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
||||
if isinstance(text, list):
|
||||
images = [[im] for im in images]
|
||||
if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
|
||||
# If there's a single sample, the image must belong to it
|
||||
images = [[images]]
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
|
||||
)
|
||||
elif isinstance(images, list) and is_image_or_image_url(images[0]):
|
||||
if isinstance(text, str) or isinstance(text, list) and len(text) == 1:
|
||||
# If there's a single sample, all images must belong to it
|
||||
images = [images]
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have supplied multiple text samples, but `images` is not a nested list. When processing multiple samples, `images` should be a list of lists of images, one list per sample."
|
||||
)
|
||||
elif isinstance(images, list) and isinstance(images[0], list) and is_image_or_image_url(images[0][0]):
|
||||
pass
|
||||
else:
|
||||
|
||||
@@ -55,6 +55,8 @@ class VideoLlavaConfig(PretrainedConfig):
|
||||
Sequence length of one image embedding.
|
||||
video_seq_length (`int`, *optional*, defaults to 2056):
|
||||
Sequence length of one video embedding.
|
||||
multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use bias in the multimodal projector.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -92,6 +94,7 @@ class VideoLlavaConfig(PretrainedConfig):
|
||||
vision_feature_layer=-2,
|
||||
image_seq_length=256,
|
||||
video_seq_length=2056,
|
||||
multimodal_projector_bias=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
@@ -102,6 +105,7 @@ class VideoLlavaConfig(PretrainedConfig):
|
||||
self.vision_feature_layer = vision_feature_layer
|
||||
self.image_seq_length = image_seq_length
|
||||
self.video_seq_length = video_seq_length
|
||||
self.multimodal_projector_bias = multimodal_projector_bias
|
||||
|
||||
self.vision_config = vision_config
|
||||
|
||||
|
||||
@@ -88,10 +88,13 @@ class VideoLlavaCausalLMOutputWithPast(ModelOutput):
|
||||
class VideoLlavaMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: VideoLlavaConfig):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
|
||||
@@ -253,7 +253,7 @@ class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
"USER: [IMG]\nWhat's the content of the image? ASSISTANT:",
|
||||
] * 5
|
||||
processor.tokenizer.pad_token = "</s>"
|
||||
image_inputs = [self.image_0] * 5
|
||||
image_inputs = [[self.image_0]] * 5
|
||||
|
||||
# Make small for checking image token expansion
|
||||
processor.image_processor.size = {"longest_edge": 30}
|
||||
|
||||
Reference in New Issue
Block a user