Large modular logic refactoring (#34487)

* rework converter

* Update modular_model_converter.py

* Update modular_model_converter.py

* Update modular_model_converter.py

* Update modular_model_converter.py

* cleaning

* cleaning

* finalize imports

* imports

* Update modular_model_converter.py

* Better renaming to avoid visiting same file multiple times

* start converting files

* style

* address most comments

* style

* remove unused stuff in get_needed_imports

* style

* move class dependency functions outside class

* Move main functions outside class

* style

* Update modular_model_converter.py

* rename func

* add augmented dependencies

* Update modular_model_converter.py

* Add types_to_file_type + tweak annotation handling

* Allow assignment dependency mapping + fix regex

* style + update modular examples

* fix modular_roberta example (wrong redefinition of __init__)

* slightly correct order in which dependencies will appear

* style

* review comments

* Performance + better handling of dependencies when they are imported

* style

* Add advanced new classes capabilities

* style

* add forgotten check

* Update modeling_llava_next_video.py

* Add prority list ordering in check_conversion as well

* Update check_modular_conversion.py

* Update configuration_gemma.py
This commit is contained in:
Cyril Vallez
2024-11-01 10:13:51 +01:00
committed by GitHub
parent 86701f2b6f
commit e2ac16b28a
19 changed files with 2726 additions and 1658 deletions

View File

@@ -8,7 +8,6 @@ from dataclasses import dataclass
from typing import ClassVar, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from ...cache_utils import Cache, StaticCache
@@ -18,92 +17,15 @@ from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
logging,
replace_return_docstrings,
)
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_new_task_model import NewTaskModelConfig
if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from ..auto import AutoModel, AutoModelForCausalLM
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "NewTaskModelConfig"
# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
# But NewTaskModel has no causal mask on prefix
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
is_training: bool = False,
token_type_ids: torch.Tensor = None,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
is_training (`bool`):
Whether the model is in training mode or in inference. The condition is checked by presence/absence of `token_type_ids/labels`
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
# Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
if sequence_length != 1:
if is_training:
causal_mask = torch.triu(causal_mask, diagonal=1)
else:
causal_mask[:, :sequence_length] = 0.0
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
if is_training:
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
)
return causal_mask
@dataclass
class NewTaskModelCausalLMOutputWithPast(ModelOutput):
"""
@@ -182,12 +104,12 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["NewTaskModelMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_sdpa = True
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
# important: this ported version of NewTaskModelisn't meant for training from scratch - only
@@ -210,14 +132,6 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA or not.
"""
return self.language_model._supports_sdpa
NEW_TASK_MODEL_INPUTS_DOCSTRING = r"""
Args:
@@ -301,11 +215,8 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = NewTaskModelMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size
self._attn_implementation = config._attn_implementation
language_model = AutoModelForCausalLM.from_config(
config=config.text_config, attn_implementation=self._attn_implementation
)
language_model = AutoModelForCausalLM.from_config(config=config.text_config)
if language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
@@ -344,6 +255,11 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
def _update_causal_mask(
self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
):
if self.config.text_config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
using_static_cache = isinstance(past_key_values, StaticCache)
dtype = inputs_embeds.dtype
min_dtype = torch.finfo(dtype).min
@@ -388,6 +304,22 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
)
return causal_mask
def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Obtains image last hidden states from the vision tower and apply multimodal projection.
Args:
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
image_outputs = self.vision_tower(pixel_values)
selected_image_feature = image_outputs.last_hidden_state
image_features = self.multi_modal_projector(selected_image_feature)
image_features = image_features / (self.config.hidden_size**0.5)
return image_features
@add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=NewTaskModelCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -426,9 +358,9 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, NewTaskModelForNewTask
>>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration
>>> model = NewTaskModelForNewTask.from_pretrained("google/NewTaskModel-test-224px-hf")
>>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/NewTaskModel-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/NewTaskModel-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"
@@ -484,6 +416,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
num_logits_to_keep=None,
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
model_inputs = self.language_model.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
@@ -493,33 +426,10 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
cache_position=cache_position,
use_cache=use_cache,
num_logits_to_keep=num_logits_to_keep,
token_type_ids=token_type_ids,
**kwargs,
)
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
dtype = self.get_output_embeddings().weight.dtype
min_dtype = torch.finfo(dtype).min
model_inputs["attention_mask"] = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs["token_type_ids"] = token_type_ids
# position_ids in NewTaskModel are 1-indexed
if model_inputs.get("position_ids") is not None:
model_inputs["position_ids"] += 1