[modular] Simplify logic and docstring handling (#39185)

* simplify a lot

* Update modular_model_converter.py

* finalize

* remove outdated functions

* apply it

* and examples
This commit is contained in:
Cyril Vallez
2025-07-07 14:52:57 +02:00
committed by GitHub
parent f16fbfb89a
commit 056fa73fae
25 changed files with 380 additions and 465 deletions

View File

@@ -125,8 +125,6 @@ class MyNewModelConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
```
new_param (`int`, *optional*, defaults to `False`):
A fun new parameter
"""
model_type = "my_new_model"

View File

@@ -437,32 +437,6 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
num_logits_to_keep: int = 0,
) -> Union[tuple, NewTaskModelCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, NewTaskModelForNewTask
>>> model = NewTaskModelForNewTask.from_pretrained("google/new_task_model2-3b-mix-224")
>>> processor = AutoProcessor.from_pretrained("google/new_task_model2-3b-mix-224")
>>> prompt = "Where is the cat standing?"
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs,)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Where is the cat standing?\nsnow"
```
Returns:
"""
vlm_outputs = super().forward(

View File

@@ -2,11 +2,122 @@ from transformers.models.llama.configuration_llama import LlamaConfig
# Example where we only want to only add a new config argument and new arg doc
# here there is no `ARG` so we are gonna take parent doc
class MyNewModelConfig(LlamaConfig):
r"""
new_param (`int`, *optional*, defaults to `False`):
A fun new parameter
This is the configuration class to store the configuration of a [`MyNewModelModel`]. It is used to instantiate an MyNewModel
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the MyNewModel-7B.
e.g. [meta-my_new_model/MyNewModel-2-7b-hf](https://huggingface.co/meta-my_new_model/MyNewModel-2-7b-hf)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the MyNewModel model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`MyNewModelModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
MyNewModel 2 up to 4096, CodeLlama up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'my_new_model3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'my_new_model3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`list[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'my_new_model3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'my_new_model3'. Scaling factor applied to high frequency components of the RoPE
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
head_dim (`int`, *optional*):
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
```python
>>> from transformers import MyNewModelModel, MyNewModelConfig
>>> # Initializing a MyNewModel my_new_model-7b style configuration
>>> configuration = MyNewModelConfig()
>>> # Initializing a model from the my_new_model-7b style configuration
>>> model = MyNewModelModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):

View File

@@ -1674,21 +1674,7 @@ class DFineForObjectDetection(DFinePreTrainedModel):
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[tuple[torch.FloatTensor], DFineObjectDetectionOutput]:
r"""
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
can choose to directly pass a flattened representation of an image.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
embedded representation.
labels (`list[Dict]` of len `(batch_size,)`, *optional*):
Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
Examples:
"""
```python
>>> import torch
>>> from transformers.image_utils import load_image
@@ -1729,7 +1715,8 @@ class DFineForObjectDetection(DFinePreTrainedModel):
Detected cat with confidence 0.956 at location [11.71, 53.52, 316.64, 472.33]
Detected remote with confidence 0.947 at location [40.46, 73.7, 175.62, 117.57]
Detected sofa with confidence 0.918 at location [0.59, 1.88, 640.25, 474.74]
```"""
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

View File

@@ -729,11 +729,6 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python

View File

@@ -746,7 +746,7 @@ class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMi
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
r"""
"""
Examples:
```python

View File

@@ -1292,11 +1292,6 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python

View File

@@ -1530,11 +1530,6 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin):
**kwargs,
) -> Union[tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python

View File

@@ -479,11 +479,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python

View File

@@ -553,11 +553,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
**kwargs,
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python

View File

@@ -658,11 +658,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
**kwargs,
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python

View File

@@ -1830,11 +1830,6 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
**kwargs,
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python

View File

@@ -39,13 +39,13 @@ class Glm4vImagesKwargs(ImagesKwargs):
class Glm4vProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: Glm4vImagesKwargs
videos_kwargs: Glm4vVideosProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": False,
},
}
images_kwargs: Glm4vImagesKwargs
videos_kwargs: Glm4vVideosProcessorKwargs
class Glm4vProcessor(ProcessorMixin):

View File

@@ -472,11 +472,6 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python

View File

@@ -1501,41 +1501,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
r"""
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
to serve as text prompt, which the Q-Former model will encode.
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
provided to serve as text prompt, which the language model can continue.
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
details.
[What are input IDs?](../glossary#input-ids)
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
Only relevant in case an encoder-decoder language model (like T5) is used.
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
config.vocab_size]`
Examples:
```python
>>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
>>> import torch

View File

@@ -901,11 +901,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, InternVLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python

View File

@@ -223,8 +223,7 @@ def apply_rotary_pos_emb_vision(
class MLCDAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper
Multi-headed attention with RoPE. Refer to papers:
"""Multi-headed attention with RoPE. Refer to papers:
- Attention is all you need:
https://huggingface.co/papers/1706.03762
- RoFormer: Enhanced Transformer with Rotary Position Embedding:

View File

@@ -221,8 +221,6 @@ class SamHQMaskDecoderConfig(PretrainedConfig):
The dimensionality of the hidden states in the IoU head module.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
vit_dim (`int`, *optional*, defaults to 768):
Dimensionality of the Vision Transformer (ViT) used in the `SamHQMaskDecoder` module.
"""

View File

@@ -77,6 +77,35 @@ class SamHQVisionConfig(SamVisionConfig):
class SamHQMaskDecoderConfig(SamMaskDecoderConfig):
r"""
This is the configuration class to store the configuration of a [`SamHQMaskDecoder`]. It is used to instantiate a SAM_HQ
mask decoder to the specified arguments, defining the model architecture. Instantiating a configuration defaults
will yield a similar configuration to that of the SAM_HQ-vit-h
[facebook/sam_hq-vit-huge](https://huggingface.co/facebook/sam_hq-vit-huge) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 256):
Dimensionality of the hidden states.
hidden_act (`str`, *optional*, defaults to `"relu"`):
The non-linear activation function used inside the `SamHQMaskDecoder` module.
mlp_dim (`int`, *optional*, defaults to 2048):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 2):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
attention_downsample_rate (`int`, *optional*, defaults to 2):
The downsampling rate of the attention layer.
num_multimask_outputs (`int`, *optional*, defaults to 3):
The number of outputs from the `SamHQMaskDecoder` module. In the Segment Anything paper, this is set to 3.
iou_head_depth (`int`, *optional*, defaults to 3):
The number of layers in the IoU head module.
iou_head_hidden_dim (`int`, *optional*, defaults to 256):
The dimensionality of the hidden states in the IoU head module.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
vit_dim (`int`, *optional*, defaults to 768):
Dimensionality of the Vision Transformer (ViT) used in the `SamHQMaskDecoder` module.
"""

View File

@@ -874,16 +874,6 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, SmolVLMCausalLMOutputWithPast]:
r"""
pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
Mask to avoid performing attention on padding pixel indices.
image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
The hidden states of the image encoder after modality projection.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `SmolVLMForConditionalGeneration`).
Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python

View File

@@ -99,8 +99,7 @@ class T5GemmaModuleConfig(PretrainedConfig):
>>> model = T5GemmaModuleModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
Module config (encoder or decoder): the same as Gemma2Config."""
```"""
model_type = "t5_gemma_module"
keys_to_ignore_at_inference = ["past_key_values"]

View File

@@ -68,10 +68,7 @@ logger = logging.get_logger(__name__)
class T5GemmaModuleConfig(Gemma2Config):
"""Module config (encoder or decoder): the same as Gemma2Config."""
def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)
pass
class T5GemmaConfig(PretrainedConfig):

View File

@@ -319,17 +319,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
class Zamba2Attention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
and "Generating Long Sequences with Sparse Transformers".
Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
(see fig. 2 in https://huggingface.co/papers/2405.16712).
Additionally, replaced
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
Multi-headed attention from 'Attention Is All You Need' paper.
Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:

View File

@@ -23,7 +23,6 @@ FILES_TO_PARSE = [
os.path.join(MODEL_ROOT, "rt_detr", "modular_rt_detr.py"),
os.path.join(MODEL_ROOT, "qwen2", "modular_qwen2.py"),
os.path.join(MODEL_ROOT, "qwen3", "modular_qwen3.py"),
os.path.join(MODEL_ROOT, "qwen3", "modular_qwen3_moe.py"),
os.path.join(MODEL_ROOT, "llava_next_video", "modular_llava_next_video.py"),
os.path.join(MODEL_ROOT, "cohere2", "modular_cohere2.py"),
os.path.join(MODEL_ROOT, "modernbert", "modular_modernbert.py"),

View File

@@ -249,90 +249,13 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
return updated_node
def get_docstring_indent(docstring):
# Match the first line after the opening triple quotes
match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
if match:
# Return the indentation spaces captured
return len(match.group(1))
return 0
def is_full_docstring(original_docstring: str, new_docstring: str, original_level: int) -> bool:
"""Check if `new_docstring` is a full docstring, or if it is only part of a docstring that should then
be merged with the existing old one.
"""
# libcst returns the docstrinbgs with literal `r"""` quotes in front
new_docstring = new_docstring.split('"""', 1)[1]
# The docstring contains Args definition, so it is self-contained
if re.search(r"\n\s*Args:\n", new_docstring):
return True
elif re.search(r"\n\s*Args:\n", original_docstring):
return False
# Check if the docstring contains args docstring (meaning it is self contained):
param_pattern = re.compile(
# |--- Group 1 ---|| Group 2 ||- Group 3 -||---------- Group 4 ----------|
rf"^\s{{0,{original_level}}}(\w+)\s*\(\s*([^, \)]*)(\s*.*?)\s*\)\s*:\s*((?:(?!\n^\s{{0,{original_level}}}\w+\s*\().)*)",
re.DOTALL | re.MULTILINE,
)
match_object = param_pattern.search(new_docstring)
if match_object is not None:
return True
# If it contains Returns, but starts with text indented with an additional 4 spaces before, it is self-contained
# (this is the scenario when using `@add_start_docstrings_to_model_forward`, but adding more args to docstring)
match_object = re.search(r"\n([^\S\n]*)Returns:\n", new_docstring)
if match_object is not None:
full_indent = match_object.group(1)
striped_doc = new_docstring.strip("\n")
if striped_doc.startswith(full_indent + " " * 4) or striped_doc.startswith(full_indent + "\t"):
return True
return False
def merge_docstrings(original_docstring, updated_docstring):
original_level = get_docstring_indent(original_docstring)
if not is_full_docstring(original_docstring, updated_docstring, original_level):
# Split the docstring at the example section, assuming `"""` is used to define the docstring
parts = original_docstring.split("```")
if "```" in updated_docstring and len(parts) > 1:
updated_docstring = updated_docstring.lstrip('r"')
new_parts = updated_docstring.split("```")
if len(new_parts) != 3:
raise ValueError("There should only be one example, and it should have opening and closing '```'")
parts[1] = new_parts[1]
updated_docstring = "".join(
[
f"\n{original_level * ' '}```",
parts[1],
"```",
parts[2],
]
)
docstring_opening, original_start_docstring = parts[0].rstrip(" \n").split('"""')[:2]
new_start_docstring = new_parts[0].rstrip(" \n")
docstring_opening += '"""'
if new_start_docstring.startswith(original_start_docstring):
updated_docstring = new_start_docstring + "\n" + updated_docstring
elif original_start_docstring.endswith(new_start_docstring):
updated_docstring = original_start_docstring + "\n" + updated_docstring
else:
updated_docstring = original_start_docstring + "\n" + new_start_docstring + "\n" + updated_docstring
updated_docstring = docstring_opening + updated_docstring
elif updated_docstring not in original_docstring:
# add tabulation if we are at the lowest level.
if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
updated_docstring = updated_docstring.replace("\n ", "\n ")
updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
return updated_docstring
class SuperTransformer(cst.CSTTransformer):
METADATA_DEPENDENCIES = (ParentNodeProvider,)
def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None):
def __init__(self, python_module: cst.Module, original_modeling_methods, modular_methods, all_bases=None):
self.python_module = python_module
self.original_methods = original_methods
self.updated_methods = updated_methods
self.original_modeling_methods = original_modeling_methods
self.modular_methods = modular_methods
self.all_assign_target = {}
self.deleted_targets = {} # child node can delete some arguments
self.all_bases = all_bases or []
@@ -414,53 +337,39 @@ class SuperTransformer(cst.CSTTransformer):
break
return new_body
def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
def replace_super_calls(self, node: cst.BaseSuite, func_name: str) -> cst.BaseSuite:
"""Updates the body of the input `node`'s `func_name` function by replacing calls
to super().func_name() with the source code of the parent class' `func_name`.
It keeps everything that is defined before `super().func_name()`.
"""
self.has_docstring = False
parent_has_docstring = False
if func_name in self.original_methods:
parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
new_body = []
has_super_call = False
modular_node_body = node.body
for i, expr in enumerate(node.body):
for i, expr in enumerate(modular_node_body):
if is_call_to_super(expr, func_name):
has_super_call = True
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :]))
original_modeling_method_body = self.original_modeling_methods[func_name].body.body
new_body.extend(self.update_body(original_modeling_method_body, modular_node_body[i + 1 :]))
new_body = self._fix_init_location(new_body)
return node.with_changes(body=new_body)
else:
expr = expr.visit(self.transformer)
if m.matches(expr, DOCSTRING_NODE):
self.has_docstring = True
if parent_has_docstring: # actually here we ought to de-duplicate?
original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
updated_docstring = expr.body[0].value.value
merged_doc = merge_docstrings(original_docstring, updated_docstring)
new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
else:
new_node = [expr]
new_body.extend(new_node)
elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
if not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])):
new_body.append(expr)
if not self.has_docstring and parent_has_docstring:
new_body = [self.original_methods[func_name].body.body[0]] + new_body
return node.with_changes(body=new_body)
def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
if updated_node.name.value in self.updated_methods:
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
name = updated_node.name.value
if name in self.modular_methods:
new_body = self.replace_super_calls(updated_node.body, name)
return updated_node.with_changes(body=new_body, params=updated_node.params)
return updated_node
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.Return:
""" "When a return statement is reached, it is replaced with the unrolled super code"""
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
func_def = self.get_metadata(ParentNodeProvider, original_node)
if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_modeling_methods:
updated_return_value = updated_node.value.with_changes(
args=[
cst.Arg(
@@ -979,55 +888,52 @@ def common_partial_suffix(str1: str, str2: str) -> str:
def replace_class_node(
mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
):
mapper: ModelFileMapper, modular_class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
) -> cst.ClassDef:
"""
Replace a class node which inherits from another modeling class. This function works in the following way:
- start from the base class node of the inherited class (a cst.Node)
- replace all methods of the base node with the methods defined in the child class
- append all new methods defined in the child class
- start from the methods and class attributes of the original modeling code node, and replace their definition
if overriden in the modular
- append all new methods and class attributes defined in the child class
- all potential method/class docstrings and decorators use the ones found in modular if any, else in original modeling
- replace all calls to super() with the unravelled code
| ```python | | ```python
| class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module):
| def __init__(self): | | def __init__(self):
Going from: | super().__init__() | to: | super().__init__(config)
| self.dropout = 0.2 | | self.dropout = 0.2
| ``` | | self.padding_idx = config.pad_token_id
| self.vocab_size = config.vocab_size
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
| self.layers = nn.ModuleList(
| [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
| )
| self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
| self.gradient_checkpointing = False
| # Initialize weights and apply final processing
| self.post_init()
| ```
"""
all_bases = [get_full_attribute_name(k.value) for k in class_node.bases]
if any(base is None for base in all_bases):
raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")
Args:
mapper (`ModelFileMapper`):
The mapper corresponding to the visited file from which the modular class node inherits.
modular_class_node (`cst.ClassDef`):
The class node as found in the modular file.
renamed_super_class (`str`):
The name of the class from which `modular_class_node` inherits after automatic renaming.
original_super_class (`str`):
The name of the class from which `modular_class_node` inherits before automatic renaming.
original_node = mapper.classes[renamed_super_class]
Returns:
A new class node corresponding to the modular definition.
"""
all_bases = [get_full_attribute_name(k.value) for k in modular_class_node.bases]
if any(base is None for base in all_bases):
raise ValueError(f"Could not parse the name of the bases for {modular_class_node.name.value}")
original_modeling_node = mapper.classes[renamed_super_class]
# Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
new_name = class_node.name
new_class_name = modular_class_node.name
# If the new class name is different from the renamed super class name, we need to update the docstrings/comments accordingly
if new_name.value != renamed_super_class:
common_suffix = common_partial_suffix(new_name.value, renamed_super_class)
if new_class_name.value != renamed_super_class:
common_suffix = common_partial_suffix(new_class_name.value, renamed_super_class)
# Note that this works even without common prefix, in which case it does not replace anything
old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "")
temp_module = cst.Module(body=[original_node])
original_node = temp_module.visit(
old, new = renamed_super_class.replace(common_suffix, ""), new_class_name.value.replace(common_suffix, "")
temp_module = cst.Module(body=[original_modeling_node])
original_modeling_node = temp_module.visit(
ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True)
).body[0]
# If we explicitly passed a new base with common suffix to an old base, it is for switching the prefix
# e.g. if the "natural" parent class is `PreTrainedModel` but we wanted to rename it to `PreTrainedVisionModel`
additional_bases = [base for base in all_bases if base != original_super_class]
new_bases = []
for original_base in original_node.bases:
new_class_bases = []
for original_base in original_modeling_node.bases:
new_base = original_base
# we only potentially switch base for Name-based bases, not Attribute
if m.matches(original_base.value, m.Name()):
@@ -1038,106 +944,125 @@ def replace_class_node(
new_name_node = original_base.value.with_changes(value=additional_base_name)
new_base = original_base.with_changes(value=new_name_node)
break
new_bases.append(new_base)
new_class_bases.append(new_base)
original_methods = {
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
for f in original_node.body.body
}
updated_methods = {
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body
}
end_meth = []
# Use class decorators redefined in modular file if any
new_class_decorators = (
modular_class_node.decorators if len(modular_class_node.decorators) > 0 else original_modeling_node.decorators
)
assign_targets = {}
docstring_node = []
# Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict
for func in original_node.body.body:
name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
new_params = updated_methods[name].params
# Replace the method in the replacement class, preserving decorators
kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
# Compute new class docstring
original_modeling_docstring = [
node for node in original_modeling_node.body.body if m.matches(node, DOCSTRING_NODE)
]
modular_docstring = [node for node in modular_class_node.body.body if m.matches(node, DOCSTRING_NODE)]
# Use class docstring in modular if any, else original modeling code docstring
new_class_docstring = modular_docstring if len(modular_docstring) > 0 else original_modeling_docstring
# Compute new class attributes
original_modeling_class_attributes = {
node.body[0].targets[0].target.value: node
for node in original_modeling_node.body.body
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()]))
}
original_modeling_class_attributes.update(
{
node.body[0].target.value: node
for node in original_modeling_node.body.body
if m.matches(node, m.SimpleStatementLine(body=[m.AnnAssign()]))
}
)
modular_class_attributes = {
node.body[0].targets[0].target.value: node
for node in modular_class_node.body.body
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()]))
}
modular_class_attributes.update(
{
node.body[0].target.value: node
for node in modular_class_node.body.body
if m.matches(node, m.SimpleStatementLine(body=[m.AnnAssign()]))
}
)
# Use all original modeling attributes, and potentially override some with values in the modular
new_class_attributes = list({**original_modeling_class_attributes, **modular_class_attributes}.values())
original_modeling_methods = {
node.name.value: node for node in original_modeling_node.body.body if m.matches(node, m.FunctionDef())
}
modular_methods = {
node.name.value: node for node in modular_class_node.body.body if m.matches(node, m.FunctionDef())
}
new_class_methods = []
# Iterate over the methods of the original modeling code, and add them to the list of methods to add
for name, node in original_modeling_methods.items():
# If the method was redefined in modular, make appropriate changes to the node
if name in modular_methods:
# Get the corresponding method node in modular
modular_node = modular_methods[name]
# If we match the pattern, we should avoid inheriting the method
if re.match(r"\ndef .*\(.*\):\n raise.*Error\(.*", mapper.python_module.code_for_node(modular_node)):
continue
# Compute new method docstring
modeling_docstring = [node_ for node_ in node.body.body if m.matches(node_, DOCSTRING_NODE)]
modular_docstring = [node_ for node_ in modular_node.body.body if m.matches(node_, DOCSTRING_NODE)]
# Use method docstring in modular if any, else original modeling code docstring
new_body = (
modular_node.body.body
if len(modular_docstring) > 0
else modeling_docstring + list(modular_node.body.body)
)
new_body = modular_node.body.with_changes(body=new_body)
# Use arguments as defined in the modular
new_params = modular_node.params
# If using the `**super_kwargs` syntax in modular, merge any existing modular arg with all the original modeling ones
kwarg_name = getattr(modular_node.params, "star_kwarg", None)
if kwarg_name and kwarg_name.name.value == "super_kwargs":
parent_params = {k.name.value: k for k in func.params.params}
parent_params.update({k.name.value: k for k in new_params.params[1:]})
new_params = new_params.with_changes(
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
)
# Keep decorators in `modular_xxx.py` if any, else original decorators
new_decorators = (
updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
)
original_modeling_params = {k.name.value: k for k in node.params.params}
modular_params = {k.name.value: k for k in new_params.params[1:]}
new_param_list = list({**original_modeling_params, **modular_params}.values())
new_params = new_params.with_changes(params=new_param_list, star_kwarg=node.params.star_kwarg)
# Keep return annotation in `modular_xxx.py` if any, else original return annotation
new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns
# Keep decorators in modular if any, else original decorators
new_decorators = modular_node.decorators if len(modular_node.decorators) > 0 else node.decorators
if not re.match(
r"\ndef .*\(.*\):\n raise.*Error\(.*",
mapper.python_module.code_for_node(updated_methods[name]),
):
func = func.with_changes(
body=updated_methods[name].body,
# Keep return annotation in modular if any, else original return annotation
new_return_annotation = modular_node.returns if modular_node.returns else node.returns
# Update the method node
node = node.with_changes(
body=new_body,
params=new_params,
decorators=new_decorators,
returns=new_return_annotation,
)
else:
continue
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
target = mapper.python_module.code_for_node(func.body[0].targets[0])
assign_targets[target] = func
elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
target = mapper.python_module.code_for_node(func.body[0].target)
assign_targets[target] = func
elif m.matches(func, DOCSTRING_NODE):
docstring_node = [func]
else:
end_meth.append(func)
new_class_methods.append(node)
# Port new methods that are defined only in modular-file and append at the end
for func in class_node.body.body:
name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class!
# Extract the original docstring
updated_docstring = func.body[0].value.value
if len(docstring_node) == 0: # If the original docstring is empty, just create one from the updated.
docstring_node = [
cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))])
]
else:
original_docstring = docstring_node[0].body[0].value.value
merged_doc = merge_docstrings(original_docstring, updated_docstring)
# Update the docstring in the original function
docstring_node = [
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
]
if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
end_meth.append(func)
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
# TODO we only use single assign might cause issues
target = mapper.python_module.code_for_node(func.body[0].targets[0])
assign_targets[target] = func
if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
target = mapper.python_module.code_for_node(func.body[0].target)
assign_targets[target] = func
end_meth = docstring_node + list(assign_targets.values()) + end_meth
for name, node in modular_methods.items():
if name not in original_modeling_methods:
new_class_methods.append(node)
# Replace the calls to `super()` with the unrolled code
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
# Recreate the whole new class body
new_class_body = new_class_docstring + new_class_attributes + new_class_methods
# Replace the calls to `super()` of the redefined modular methods with the unrolled code
result_node = original_modeling_node.with_changes(body=cst.IndentedBlock(body=new_class_body))
temp_module = cst.Module(body=[result_node])
new_module = MetadataWrapper(temp_module)
new_replacement_class = new_module.visit(
SuperTransformer(temp_module, original_methods, updated_methods, all_bases)
SuperTransformer(temp_module, original_modeling_methods, modular_methods, all_bases)
)
new_replacement_body = new_replacement_class.body[0].body # get the indented block
new_class_body = new_replacement_class.body[0].body # get the indented block
# Use decorators redefined in `modular_xxx.py` if any
new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators
return original_node.with_changes(
body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name
return original_modeling_node.with_changes(
body=new_class_body, decorators=new_class_decorators, bases=new_class_bases, name=new_class_name
)