Large modular logic refactoring (#34487)
* rework converter * Update modular_model_converter.py * Update modular_model_converter.py * Update modular_model_converter.py * Update modular_model_converter.py * cleaning * cleaning * finalize imports * imports * Update modular_model_converter.py * Better renaming to avoid visiting same file multiple times * start converting files * style * address most comments * style * remove unused stuff in get_needed_imports * style * move class dependency functions outside class * Move main functions outside class * style * Update modular_model_converter.py * rename func * add augmented dependencies * Update modular_model_converter.py * Add types_to_file_type + tweak annotation handling * Allow assignment dependency mapping + fix regex * style + update modular examples * fix modular_roberta example (wrong redefinition of __init__) * slightly correct order in which dependencies will appear * style * review comments * Performance + better handling of dependencies when they are imported * style * Add advanced new classes capabilities * style * add forgotten check * Update modeling_llava_next_video.py * Add prority list ordering in check_conversion as well * Update check_modular_conversion.py * Update configuration_gemma.py
This commit is contained in:
@@ -8,7 +8,6 @@ from dataclasses import dataclass
|
||||
from typing import ClassVar, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache, StaticCache
|
||||
@@ -18,92 +17,15 @@ from ...utils import (
|
||||
ModelOutput,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
from .configuration_new_task_model import NewTaskModelConfig
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
||||
|
||||
from ..auto import AutoModel, AutoModelForCausalLM
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "NewTaskModelConfig"
|
||||
|
||||
|
||||
# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
||||
# But NewTaskModel has no causal mask on prefix
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
min_dtype: float,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_training: bool = False,
|
||||
token_type_ids: torch.Tensor = None,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to plcae the 4D attention mask on.
|
||||
min_dtype (`float`):
|
||||
The minimum value representable with the dtype `dtype`.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
is_training (`bool`):
|
||||
Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels`
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
|
||||
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
|
||||
if sequence_length != 1:
|
||||
if is_training:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
else:
|
||||
causal_mask[:, :sequence_length] = 0.0
|
||||
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
||||
if is_training:
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewTaskModelCausalLMOutputWithPast(ModelOutput):
|
||||
"""
|
||||
@@ -182,12 +104,12 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["NewTaskModelMultiModalProjector"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_cache_class = True
|
||||
_supports_quantized_cache = True
|
||||
_supports_static_cache = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
# important: this ported version of NewTaskModelisn't meant for training from scratch - only
|
||||
@@ -210,14 +132,6 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
@property
|
||||
def _supports_sdpa(self):
|
||||
"""
|
||||
Retrieve language_model's attribute to check whether the model supports
|
||||
SDPA or not.
|
||||
"""
|
||||
return self.language_model._supports_sdpa
|
||||
|
||||
|
||||
NEW_TASK_MODEL_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
@@ -301,11 +215,8 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.multi_modal_projector = NewTaskModelMultiModalProjector(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self._attn_implementation = config._attn_implementation
|
||||
|
||||
language_model = AutoModelForCausalLM.from_config(
|
||||
config=config.text_config, attn_implementation=self._attn_implementation
|
||||
)
|
||||
language_model = AutoModelForCausalLM.from_config(config=config.text_config)
|
||||
|
||||
if language_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
||||
@@ -344,6 +255,11 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
def _update_causal_mask(
|
||||
self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
|
||||
):
|
||||
if self.config.text_config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
dtype = inputs_embeds.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
@@ -388,6 +304,22 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor):
|
||||
"""
|
||||
Obtains image last hidden states from the vision tower and apply multimodal projection.
|
||||
|
||||
Args:
|
||||
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
||||
The tensors corresponding to the input images.
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
image_outputs = self.vision_tower(pixel_values)
|
||||
selected_image_feature = image_outputs.last_hidden_state
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_features = image_features / (self.config.hidden_size**0.5)
|
||||
return image_features
|
||||
|
||||
@add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -426,9 +358,9 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, NewTaskModelForNewTask
|
||||
>>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration
|
||||
|
||||
>>> model = NewTaskModelForNewTask.from_pretrained("google/NewTaskModel-test-224px-hf")
|
||||
>>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/NewTaskModel-test-224px-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/NewTaskModel-test-224px-hf")
|
||||
|
||||
>>> prompt = "answer en Where is the cow standing?"
|
||||
@@ -484,6 +416,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
num_logits_to_keep=None,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
||||
model_inputs = self.language_model.prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
@@ -493,33 +426,10 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
cache_position=cache_position,
|
||||
use_cache=use_cache,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
token_type_ids=token_type_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||
if model_inputs["inputs_embeds"] is not None:
|
||||
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||
device = model_inputs["inputs_embeds"].device
|
||||
else:
|
||||
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||
device = model_inputs["input_ids"].device
|
||||
|
||||
dtype = self.get_output_embeddings().weight.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=past_key_values.get_max_length(),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
min_dtype=min_dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
model_inputs["token_type_ids"] = token_type_ids
|
||||
|
||||
# position_ids in NewTaskModel are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
model_inputs["position_ids"] += 1
|
||||
|
||||
Reference in New Issue
Block a user