[modular] Simplify logic and docstring handling (#39185)
* simplify a lot * Update modular_model_converter.py * finalize * remove outdated functions * apply it * and examples
This commit is contained in:
@@ -125,8 +125,6 @@ class MyNewModelConfig(PretrainedConfig):
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
new_param (`int`, *optional*, defaults to `False`):
|
||||
A fun new parameter
|
||||
"""
|
||||
|
||||
model_type = "my_new_model"
|
||||
|
||||
@@ -437,32 +437,6 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
num_logits_to_keep: int = 0,
|
||||
) -> Union[tuple, NewTaskModelCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, NewTaskModelForNewTask
|
||||
|
||||
>>> model = NewTaskModelForNewTask.from_pretrained("google/new_task_model2-3b-mix-224")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/new_task_model2-3b-mix-224")
|
||||
|
||||
>>> prompt = "Where is the cat standing?"
|
||||
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs,)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Where is the cat standing?\nsnow"
|
||||
```
|
||||
Returns:
|
||||
"""
|
||||
vlm_outputs = super().forward(
|
||||
|
||||
@@ -2,11 +2,122 @@ from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
# Example where we only want to only add a new config argument and new arg doc
|
||||
# here there is no `ARG` so we are gonna take parent doc
|
||||
class MyNewModelConfig(LlamaConfig):
|
||||
r"""
|
||||
new_param (`int`, *optional*, defaults to `False`):
|
||||
A fun new parameter
|
||||
This is the configuration class to store the configuration of a [`MyNewModelModel`]. It is used to instantiate an MyNewModel
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the MyNewModel-7B.
|
||||
e.g. [meta-my_new_model/MyNewModel-2-7b-hf](https://huggingface.co/meta-my_new_model/MyNewModel-2-7b-hf)
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the MyNewModel model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`MyNewModelModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
|
||||
MyNewModel 2 up to 4096, CodeLlama up to 16384.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
pretraining_tp (`int`, *optional*, defaults to 1):
|
||||
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
||||
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
|
||||
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
|
||||
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'my_new_model3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'my_new_model3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'my_new_model3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'my_new_model3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
||||
head_dim (`int`, *optional*):
|
||||
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
|
||||
|
||||
```python
|
||||
>>> from transformers import MyNewModelModel, MyNewModelConfig
|
||||
|
||||
>>> # Initializing a MyNewModel my_new_model-7b style configuration
|
||||
>>> configuration = MyNewModelConfig()
|
||||
|
||||
>>> # Initializing a model from the my_new_model-7b style configuration
|
||||
>>> model = MyNewModelModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, mlp_bias=True, new_param=0, **super_kwargs):
|
||||
|
||||
@@ -1674,21 +1674,7 @@ class DFineForObjectDetection(DFinePreTrainedModel):
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.FloatTensor], DFineObjectDetectionOutput]:
|
||||
r"""
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
|
||||
can choose to directly pass a flattened representation of an image.
|
||||
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
|
||||
Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
|
||||
embedded representation.
|
||||
labels (`list[Dict]` of len `(batch_size,)`, *optional*):
|
||||
Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
|
||||
following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
|
||||
respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
|
||||
in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
|
||||
|
||||
Examples:
|
||||
|
||||
"""
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers.image_utils import load_image
|
||||
@@ -1729,7 +1715,8 @@ class DFineForObjectDetection(DFinePreTrainedModel):
|
||||
Detected cat with confidence 0.956 at location [11.71, 53.52, 316.64, 472.33]
|
||||
Detected remote with confidence 0.947 at location [40.46, 73.7, 175.62, 117.57]
|
||||
Detected sofa with confidence 0.918 at location [0.59, 1.88, 640.25, 474.74]
|
||||
```"""
|
||||
```
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
|
||||
@@ -729,11 +729,6 @@ class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
||||
@@ -746,7 +746,7 @@ class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMi
|
||||
output_attentions: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> BackboneOutput:
|
||||
r"""
|
||||
"""
|
||||
Examples:
|
||||
|
||||
```python
|
||||
|
||||
@@ -1292,11 +1292,6 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
||||
@@ -1530,11 +1530,6 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin):
|
||||
**kwargs,
|
||||
) -> Union[tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
||||
@@ -479,11 +479,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
||||
@@ -553,11 +553,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
||||
@@ -658,11 +658,6 @@ class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
||||
@@ -1830,11 +1830,6 @@ class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin):
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
||||
@@ -39,13 +39,13 @@ class Glm4vImagesKwargs(ImagesKwargs):
|
||||
|
||||
|
||||
class Glm4vProcessorKwargs(ProcessingKwargs, total=False):
|
||||
images_kwargs: Glm4vImagesKwargs
|
||||
videos_kwargs: Glm4vVideosProcessorKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
}
|
||||
images_kwargs: Glm4vImagesKwargs
|
||||
videos_kwargs: Glm4vVideosProcessorKwargs
|
||||
|
||||
|
||||
class Glm4vProcessor(ProcessorMixin):
|
||||
|
||||
@@ -472,11 +472,6 @@ class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
||||
@@ -1501,41 +1501,6 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
|
||||
r"""
|
||||
qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
|
||||
to serve as text prompt, which the Q-Former model will encode.
|
||||
|
||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||
details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
|
||||
provided to serve as text prompt, which the language model can continue.
|
||||
|
||||
Indices can be obtained using [`InstructBlipVideoProcessor`]. See [`InstructBlipVideoProcessor.__call__`] for
|
||||
details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||||
be used by default.
|
||||
|
||||
Only relevant in case an encoder-decoder language model (like T5) is used.
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
|
||||
1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
|
||||
config.vocab_size]`
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
|
||||
>>> import torch
|
||||
|
||||
@@ -901,11 +901,6 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[tuple, InternVLCausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
||||
@@ -223,8 +223,7 @@ def apply_rotary_pos_emb_vision(
|
||||
|
||||
|
||||
class MLCDAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper
|
||||
Multi-headed attention with RoPE. Refer to papers:
|
||||
"""Multi-headed attention with RoPE. Refer to papers:
|
||||
- Attention is all you need:
|
||||
https://huggingface.co/papers/1706.03762
|
||||
- RoFormer: Enhanced Transformer with Rotary Position Embedding:
|
||||
|
||||
@@ -221,8 +221,6 @@ class SamHQMaskDecoderConfig(PretrainedConfig):
|
||||
The dimensionality of the hidden states in the IoU head module.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
|
||||
|
||||
vit_dim (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the Vision Transformer (ViT) used in the `SamHQMaskDecoder` module.
|
||||
"""
|
||||
|
||||
@@ -77,6 +77,35 @@ class SamHQVisionConfig(SamVisionConfig):
|
||||
|
||||
class SamHQMaskDecoderConfig(SamMaskDecoderConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`SamHQMaskDecoder`]. It is used to instantiate a SAM_HQ
|
||||
mask decoder to the specified arguments, defining the model architecture. Instantiating a configuration defaults
|
||||
will yield a similar configuration to that of the SAM_HQ-vit-h
|
||||
[facebook/sam_hq-vit-huge](https://huggingface.co/facebook/sam_hq-vit-huge) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 256):
|
||||
Dimensionality of the hidden states.
|
||||
hidden_act (`str`, *optional*, defaults to `"relu"`):
|
||||
The non-linear activation function used inside the `SamHQMaskDecoder` module.
|
||||
mlp_dim (`int`, *optional*, defaults to 2048):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 2):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 8):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
attention_downsample_rate (`int`, *optional*, defaults to 2):
|
||||
The downsampling rate of the attention layer.
|
||||
num_multimask_outputs (`int`, *optional*, defaults to 3):
|
||||
The number of outputs from the `SamHQMaskDecoder` module. In the Segment Anything paper, this is set to 3.
|
||||
iou_head_depth (`int`, *optional*, defaults to 3):
|
||||
The number of layers in the IoU head module.
|
||||
iou_head_hidden_dim (`int`, *optional*, defaults to 256):
|
||||
The dimensionality of the hidden states in the IoU head module.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
vit_dim (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the Vision Transformer (ViT) used in the `SamHQMaskDecoder` module.
|
||||
"""
|
||||
|
||||
@@ -874,16 +874,6 @@ class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin):
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[tuple, SmolVLMCausalLMOutputWithPast]:
|
||||
r"""
|
||||
pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
|
||||
Mask to avoid performing attention on padding pixel indices.
|
||||
image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
|
||||
The hidden states of the image encoder after modality projection.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `SmolVLMForConditionalGeneration`).
|
||||
Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
|
||||
computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
|
||||
@@ -99,8 +99,7 @@ class T5GemmaModuleConfig(PretrainedConfig):
|
||||
>>> model = T5GemmaModuleModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
Module config (encoder or decoder): the same as Gemma2Config."""
|
||||
```"""
|
||||
|
||||
model_type = "t5_gemma_module"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
@@ -68,10 +68,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class T5GemmaModuleConfig(Gemma2Config):
|
||||
"""Module config (encoder or decoder): the same as Gemma2Config."""
|
||||
|
||||
def __init__(self, **super_kwargs):
|
||||
super().__init__(**super_kwargs)
|
||||
pass
|
||||
|
||||
|
||||
class T5GemmaConfig(PretrainedConfig):
|
||||
|
||||
@@ -319,17 +319,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
|
||||
class Zamba2Attention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
||||
and "Generating Long Sequences with Sparse Transformers".
|
||||
|
||||
Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
|
||||
The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
|
||||
The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
|
||||
(see fig. 2 in https://huggingface.co/papers/2405.16712).
|
||||
Additionally, replaced
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
|
||||
|
||||
Multi-headed attention from 'Attention Is All You Need' paper.
|
||||
|
||||
Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
|
||||
|
||||
@@ -23,7 +23,6 @@ FILES_TO_PARSE = [
|
||||
os.path.join(MODEL_ROOT, "rt_detr", "modular_rt_detr.py"),
|
||||
os.path.join(MODEL_ROOT, "qwen2", "modular_qwen2.py"),
|
||||
os.path.join(MODEL_ROOT, "qwen3", "modular_qwen3.py"),
|
||||
os.path.join(MODEL_ROOT, "qwen3", "modular_qwen3_moe.py"),
|
||||
os.path.join(MODEL_ROOT, "llava_next_video", "modular_llava_next_video.py"),
|
||||
os.path.join(MODEL_ROOT, "cohere2", "modular_cohere2.py"),
|
||||
os.path.join(MODEL_ROOT, "modernbert", "modular_modernbert.py"),
|
||||
|
||||
@@ -249,90 +249,13 @@ class ReplaceMethodCallTransformer(cst.CSTTransformer):
|
||||
return updated_node
|
||||
|
||||
|
||||
def get_docstring_indent(docstring):
|
||||
# Match the first line after the opening triple quotes
|
||||
match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
|
||||
if match:
|
||||
# Return the indentation spaces captured
|
||||
return len(match.group(1))
|
||||
return 0
|
||||
|
||||
|
||||
def is_full_docstring(original_docstring: str, new_docstring: str, original_level: int) -> bool:
|
||||
"""Check if `new_docstring` is a full docstring, or if it is only part of a docstring that should then
|
||||
be merged with the existing old one.
|
||||
"""
|
||||
# libcst returns the docstrinbgs with literal `r"""` quotes in front
|
||||
new_docstring = new_docstring.split('"""', 1)[1]
|
||||
# The docstring contains Args definition, so it is self-contained
|
||||
if re.search(r"\n\s*Args:\n", new_docstring):
|
||||
return True
|
||||
elif re.search(r"\n\s*Args:\n", original_docstring):
|
||||
return False
|
||||
# Check if the docstring contains args docstring (meaning it is self contained):
|
||||
param_pattern = re.compile(
|
||||
# |--- Group 1 ---|| Group 2 ||- Group 3 -||---------- Group 4 ----------|
|
||||
rf"^\s{{0,{original_level}}}(\w+)\s*\(\s*([^, \)]*)(\s*.*?)\s*\)\s*:\s*((?:(?!\n^\s{{0,{original_level}}}\w+\s*\().)*)",
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
match_object = param_pattern.search(new_docstring)
|
||||
if match_object is not None:
|
||||
return True
|
||||
# If it contains Returns, but starts with text indented with an additional 4 spaces before, it is self-contained
|
||||
# (this is the scenario when using `@add_start_docstrings_to_model_forward`, but adding more args to docstring)
|
||||
match_object = re.search(r"\n([^\S\n]*)Returns:\n", new_docstring)
|
||||
if match_object is not None:
|
||||
full_indent = match_object.group(1)
|
||||
striped_doc = new_docstring.strip("\n")
|
||||
if striped_doc.startswith(full_indent + " " * 4) or striped_doc.startswith(full_indent + "\t"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def merge_docstrings(original_docstring, updated_docstring):
|
||||
original_level = get_docstring_indent(original_docstring)
|
||||
if not is_full_docstring(original_docstring, updated_docstring, original_level):
|
||||
# Split the docstring at the example section, assuming `"""` is used to define the docstring
|
||||
parts = original_docstring.split("```")
|
||||
if "```" in updated_docstring and len(parts) > 1:
|
||||
updated_docstring = updated_docstring.lstrip('r"')
|
||||
new_parts = updated_docstring.split("```")
|
||||
if len(new_parts) != 3:
|
||||
raise ValueError("There should only be one example, and it should have opening and closing '```'")
|
||||
parts[1] = new_parts[1]
|
||||
updated_docstring = "".join(
|
||||
[
|
||||
f"\n{original_level * ' '}```",
|
||||
parts[1],
|
||||
"```",
|
||||
parts[2],
|
||||
]
|
||||
)
|
||||
docstring_opening, original_start_docstring = parts[0].rstrip(" \n").split('"""')[:2]
|
||||
new_start_docstring = new_parts[0].rstrip(" \n")
|
||||
docstring_opening += '"""'
|
||||
if new_start_docstring.startswith(original_start_docstring):
|
||||
updated_docstring = new_start_docstring + "\n" + updated_docstring
|
||||
elif original_start_docstring.endswith(new_start_docstring):
|
||||
updated_docstring = original_start_docstring + "\n" + updated_docstring
|
||||
else:
|
||||
updated_docstring = original_start_docstring + "\n" + new_start_docstring + "\n" + updated_docstring
|
||||
updated_docstring = docstring_opening + updated_docstring
|
||||
elif updated_docstring not in original_docstring:
|
||||
# add tabulation if we are at the lowest level.
|
||||
if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring):
|
||||
updated_docstring = updated_docstring.replace("\n ", "\n ")
|
||||
updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n')
|
||||
return updated_docstring
|
||||
|
||||
|
||||
class SuperTransformer(cst.CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (ParentNodeProvider,)
|
||||
|
||||
def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None):
|
||||
def __init__(self, python_module: cst.Module, original_modeling_methods, modular_methods, all_bases=None):
|
||||
self.python_module = python_module
|
||||
self.original_methods = original_methods
|
||||
self.updated_methods = updated_methods
|
||||
self.original_modeling_methods = original_modeling_methods
|
||||
self.modular_methods = modular_methods
|
||||
self.all_assign_target = {}
|
||||
self.deleted_targets = {} # child node can delete some arguments
|
||||
self.all_bases = all_bases or []
|
||||
@@ -414,53 +337,39 @@ class SuperTransformer(cst.CSTTransformer):
|
||||
break
|
||||
return new_body
|
||||
|
||||
def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode:
|
||||
def replace_super_calls(self, node: cst.BaseSuite, func_name: str) -> cst.BaseSuite:
|
||||
"""Updates the body of the input `node`'s `func_name` function by replacing calls
|
||||
to super().func_name() with the source code of the parent class' `func_name`.
|
||||
It keeps everything that is defined before `super().func_name()`.
|
||||
"""
|
||||
self.has_docstring = False
|
||||
parent_has_docstring = False
|
||||
if func_name in self.original_methods:
|
||||
parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE)
|
||||
new_body = []
|
||||
has_super_call = False
|
||||
modular_node_body = node.body
|
||||
|
||||
for i, expr in enumerate(node.body):
|
||||
for i, expr in enumerate(modular_node_body):
|
||||
if is_call_to_super(expr, func_name):
|
||||
has_super_call = True
|
||||
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :]))
|
||||
original_modeling_method_body = self.original_modeling_methods[func_name].body.body
|
||||
new_body.extend(self.update_body(original_modeling_method_body, modular_node_body[i + 1 :]))
|
||||
new_body = self._fix_init_location(new_body)
|
||||
return node.with_changes(body=new_body)
|
||||
else:
|
||||
expr = expr.visit(self.transformer)
|
||||
if m.matches(expr, DOCSTRING_NODE):
|
||||
self.has_docstring = True
|
||||
if parent_has_docstring: # actually here we ought to de-duplicate?
|
||||
original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value
|
||||
updated_docstring = expr.body[0].value.value
|
||||
merged_doc = merge_docstrings(original_docstring, updated_docstring)
|
||||
new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])]
|
||||
else:
|
||||
new_node = [expr]
|
||||
new_body.extend(new_node)
|
||||
elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call:
|
||||
if not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])):
|
||||
new_body.append(expr)
|
||||
if not self.has_docstring and parent_has_docstring:
|
||||
new_body = [self.original_methods[func_name].body.body[0]] + new_body
|
||||
|
||||
return node.with_changes(body=new_body)
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
|
||||
if updated_node.name.value in self.updated_methods:
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
name = updated_node.name.value
|
||||
if name in self.modular_methods:
|
||||
new_body = self.replace_super_calls(updated_node.body, name)
|
||||
return updated_node.with_changes(body=new_body, params=updated_node.params)
|
||||
return updated_node
|
||||
|
||||
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode:
|
||||
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.Return:
|
||||
""" "When a return statement is reached, it is replaced with the unrolled super code"""
|
||||
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))):
|
||||
func_def = self.get_metadata(ParentNodeProvider, original_node)
|
||||
if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
|
||||
if m.matched(func_def, m.FunctionDef()) and func_def.name.value in self.original_modeling_methods:
|
||||
updated_return_value = updated_node.value.with_changes(
|
||||
args=[
|
||||
cst.Arg(
|
||||
@@ -979,55 +888,52 @@ def common_partial_suffix(str1: str, str2: str) -> str:
|
||||
|
||||
|
||||
def replace_class_node(
|
||||
mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
|
||||
):
|
||||
mapper: ModelFileMapper, modular_class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
|
||||
) -> cst.ClassDef:
|
||||
"""
|
||||
Replace a class node which inherits from another modeling class. This function works in the following way:
|
||||
- start from the base class node of the inherited class (a cst.Node)
|
||||
- replace all methods of the base node with the methods defined in the child class
|
||||
- append all new methods defined in the child class
|
||||
- start from the methods and class attributes of the original modeling code node, and replace their definition
|
||||
if overriden in the modular
|
||||
- append all new methods and class attributes defined in the child class
|
||||
- all potential method/class docstrings and decorators use the ones found in modular if any, else in original modeling
|
||||
- replace all calls to super() with the unravelled code
|
||||
|
||||
| ```python | | ```python
|
||||
| class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module):
|
||||
| def __init__(self): | | def __init__(self):
|
||||
Going from: | super().__init__() | to: | super().__init__(config)
|
||||
| self.dropout = 0.2 | | self.dropout = 0.2
|
||||
| ``` | | self.padding_idx = config.pad_token_id
|
||||
| self.vocab_size = config.vocab_size
|
||||
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
| self.layers = nn.ModuleList(
|
||||
| [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
| )
|
||||
| self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
| self.gradient_checkpointing = False
|
||||
| # Initialize weights and apply final processing
|
||||
| self.post_init()
|
||||
| ```
|
||||
"""
|
||||
all_bases = [get_full_attribute_name(k.value) for k in class_node.bases]
|
||||
if any(base is None for base in all_bases):
|
||||
raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}")
|
||||
Args:
|
||||
mapper (`ModelFileMapper`):
|
||||
The mapper corresponding to the visited file from which the modular class node inherits.
|
||||
modular_class_node (`cst.ClassDef`):
|
||||
The class node as found in the modular file.
|
||||
renamed_super_class (`str`):
|
||||
The name of the class from which `modular_class_node` inherits after automatic renaming.
|
||||
original_super_class (`str`):
|
||||
The name of the class from which `modular_class_node` inherits before automatic renaming.
|
||||
|
||||
original_node = mapper.classes[renamed_super_class]
|
||||
Returns:
|
||||
A new class node corresponding to the modular definition.
|
||||
"""
|
||||
all_bases = [get_full_attribute_name(k.value) for k in modular_class_node.bases]
|
||||
if any(base is None for base in all_bases):
|
||||
raise ValueError(f"Could not parse the name of the bases for {modular_class_node.name.value}")
|
||||
|
||||
original_modeling_node = mapper.classes[renamed_super_class]
|
||||
# Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`)
|
||||
new_name = class_node.name
|
||||
new_class_name = modular_class_node.name
|
||||
|
||||
# If the new class name is different from the renamed super class name, we need to update the docstrings/comments accordingly
|
||||
if new_name.value != renamed_super_class:
|
||||
common_suffix = common_partial_suffix(new_name.value, renamed_super_class)
|
||||
if new_class_name.value != renamed_super_class:
|
||||
common_suffix = common_partial_suffix(new_class_name.value, renamed_super_class)
|
||||
# Note that this works even without common prefix, in which case it does not replace anything
|
||||
old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "")
|
||||
temp_module = cst.Module(body=[original_node])
|
||||
original_node = temp_module.visit(
|
||||
old, new = renamed_super_class.replace(common_suffix, ""), new_class_name.value.replace(common_suffix, "")
|
||||
temp_module = cst.Module(body=[original_modeling_node])
|
||||
original_modeling_node = temp_module.visit(
|
||||
ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True)
|
||||
).body[0]
|
||||
|
||||
# If we explicitly passed a new base with common suffix to an old base, it is for switching the prefix
|
||||
# e.g. if the "natural" parent class is `PreTrainedModel` but we wanted to rename it to `PreTrainedVisionModel`
|
||||
additional_bases = [base for base in all_bases if base != original_super_class]
|
||||
new_bases = []
|
||||
for original_base in original_node.bases:
|
||||
new_class_bases = []
|
||||
for original_base in original_modeling_node.bases:
|
||||
new_base = original_base
|
||||
# we only potentially switch base for Name-based bases, not Attribute
|
||||
if m.matches(original_base.value, m.Name()):
|
||||
@@ -1038,106 +944,125 @@ def replace_class_node(
|
||||
new_name_node = original_base.value.with_changes(value=additional_base_name)
|
||||
new_base = original_base.with_changes(value=new_name_node)
|
||||
break
|
||||
new_bases.append(new_base)
|
||||
new_class_bases.append(new_base)
|
||||
|
||||
original_methods = {
|
||||
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f
|
||||
for f in original_node.body.body
|
||||
}
|
||||
updated_methods = {
|
||||
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body
|
||||
}
|
||||
end_meth = []
|
||||
# Use class decorators redefined in modular file if any
|
||||
new_class_decorators = (
|
||||
modular_class_node.decorators if len(modular_class_node.decorators) > 0 else original_modeling_node.decorators
|
||||
)
|
||||
|
||||
assign_targets = {}
|
||||
docstring_node = []
|
||||
# Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict
|
||||
for func in original_node.body.body:
|
||||
name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
|
||||
if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None:
|
||||
new_params = updated_methods[name].params
|
||||
# Replace the method in the replacement class, preserving decorators
|
||||
kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
|
||||
# Compute new class docstring
|
||||
original_modeling_docstring = [
|
||||
node for node in original_modeling_node.body.body if m.matches(node, DOCSTRING_NODE)
|
||||
]
|
||||
modular_docstring = [node for node in modular_class_node.body.body if m.matches(node, DOCSTRING_NODE)]
|
||||
# Use class docstring in modular if any, else original modeling code docstring
|
||||
new_class_docstring = modular_docstring if len(modular_docstring) > 0 else original_modeling_docstring
|
||||
|
||||
# Compute new class attributes
|
||||
original_modeling_class_attributes = {
|
||||
node.body[0].targets[0].target.value: node
|
||||
for node in original_modeling_node.body.body
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()]))
|
||||
}
|
||||
original_modeling_class_attributes.update(
|
||||
{
|
||||
node.body[0].target.value: node
|
||||
for node in original_modeling_node.body.body
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.AnnAssign()]))
|
||||
}
|
||||
)
|
||||
modular_class_attributes = {
|
||||
node.body[0].targets[0].target.value: node
|
||||
for node in modular_class_node.body.body
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()]))
|
||||
}
|
||||
modular_class_attributes.update(
|
||||
{
|
||||
node.body[0].target.value: node
|
||||
for node in modular_class_node.body.body
|
||||
if m.matches(node, m.SimpleStatementLine(body=[m.AnnAssign()]))
|
||||
}
|
||||
)
|
||||
# Use all original modeling attributes, and potentially override some with values in the modular
|
||||
new_class_attributes = list({**original_modeling_class_attributes, **modular_class_attributes}.values())
|
||||
|
||||
original_modeling_methods = {
|
||||
node.name.value: node for node in original_modeling_node.body.body if m.matches(node, m.FunctionDef())
|
||||
}
|
||||
modular_methods = {
|
||||
node.name.value: node for node in modular_class_node.body.body if m.matches(node, m.FunctionDef())
|
||||
}
|
||||
|
||||
new_class_methods = []
|
||||
# Iterate over the methods of the original modeling code, and add them to the list of methods to add
|
||||
for name, node in original_modeling_methods.items():
|
||||
# If the method was redefined in modular, make appropriate changes to the node
|
||||
if name in modular_methods:
|
||||
# Get the corresponding method node in modular
|
||||
modular_node = modular_methods[name]
|
||||
|
||||
# If we match the pattern, we should avoid inheriting the method
|
||||
if re.match(r"\ndef .*\(.*\):\n raise.*Error\(.*", mapper.python_module.code_for_node(modular_node)):
|
||||
continue
|
||||
|
||||
# Compute new method docstring
|
||||
modeling_docstring = [node_ for node_ in node.body.body if m.matches(node_, DOCSTRING_NODE)]
|
||||
modular_docstring = [node_ for node_ in modular_node.body.body if m.matches(node_, DOCSTRING_NODE)]
|
||||
# Use method docstring in modular if any, else original modeling code docstring
|
||||
new_body = (
|
||||
modular_node.body.body
|
||||
if len(modular_docstring) > 0
|
||||
else modeling_docstring + list(modular_node.body.body)
|
||||
)
|
||||
new_body = modular_node.body.with_changes(body=new_body)
|
||||
|
||||
# Use arguments as defined in the modular
|
||||
new_params = modular_node.params
|
||||
|
||||
# If using the `**super_kwargs` syntax in modular, merge any existing modular arg with all the original modeling ones
|
||||
kwarg_name = getattr(modular_node.params, "star_kwarg", None)
|
||||
if kwarg_name and kwarg_name.name.value == "super_kwargs":
|
||||
parent_params = {k.name.value: k for k in func.params.params}
|
||||
parent_params.update({k.name.value: k for k in new_params.params[1:]})
|
||||
new_params = new_params.with_changes(
|
||||
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg
|
||||
)
|
||||
# Keep decorators in `modular_xxx.py` if any, else original decorators
|
||||
new_decorators = (
|
||||
updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators
|
||||
)
|
||||
original_modeling_params = {k.name.value: k for k in node.params.params}
|
||||
modular_params = {k.name.value: k for k in new_params.params[1:]}
|
||||
new_param_list = list({**original_modeling_params, **modular_params}.values())
|
||||
new_params = new_params.with_changes(params=new_param_list, star_kwarg=node.params.star_kwarg)
|
||||
|
||||
# Keep return annotation in `modular_xxx.py` if any, else original return annotation
|
||||
new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns
|
||||
# Keep decorators in modular if any, else original decorators
|
||||
new_decorators = modular_node.decorators if len(modular_node.decorators) > 0 else node.decorators
|
||||
|
||||
if not re.match(
|
||||
r"\ndef .*\(.*\):\n raise.*Error\(.*",
|
||||
mapper.python_module.code_for_node(updated_methods[name]),
|
||||
):
|
||||
func = func.with_changes(
|
||||
body=updated_methods[name].body,
|
||||
# Keep return annotation in modular if any, else original return annotation
|
||||
new_return_annotation = modular_node.returns if modular_node.returns else node.returns
|
||||
|
||||
# Update the method node
|
||||
node = node.with_changes(
|
||||
body=new_body,
|
||||
params=new_params,
|
||||
decorators=new_decorators,
|
||||
returns=new_return_annotation,
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
target = mapper.python_module.code_for_node(func.body[0].targets[0])
|
||||
assign_targets[target] = func
|
||||
elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
|
||||
target = mapper.python_module.code_for_node(func.body[0].target)
|
||||
assign_targets[target] = func
|
||||
elif m.matches(func, DOCSTRING_NODE):
|
||||
docstring_node = [func]
|
||||
else:
|
||||
end_meth.append(func)
|
||||
new_class_methods.append(node)
|
||||
|
||||
# Port new methods that are defined only in modular-file and append at the end
|
||||
for func in class_node.body.body:
|
||||
name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func)
|
||||
if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class!
|
||||
# Extract the original docstring
|
||||
updated_docstring = func.body[0].value.value
|
||||
if len(docstring_node) == 0: # If the original docstring is empty, just create one from the updated.
|
||||
docstring_node = [
|
||||
cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))])
|
||||
]
|
||||
else:
|
||||
original_docstring = docstring_node[0].body[0].value.value
|
||||
merged_doc = merge_docstrings(original_docstring, updated_docstring)
|
||||
# Update the docstring in the original function
|
||||
docstring_node = [
|
||||
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])
|
||||
]
|
||||
if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef):
|
||||
end_meth.append(func)
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])):
|
||||
# TODO we only use single assign might cause issues
|
||||
target = mapper.python_module.code_for_node(func.body[0].targets[0])
|
||||
assign_targets[target] = func
|
||||
if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])):
|
||||
target = mapper.python_module.code_for_node(func.body[0].target)
|
||||
assign_targets[target] = func
|
||||
end_meth = docstring_node + list(assign_targets.values()) + end_meth
|
||||
for name, node in modular_methods.items():
|
||||
if name not in original_modeling_methods:
|
||||
new_class_methods.append(node)
|
||||
|
||||
# Replace the calls to `super()` with the unrolled code
|
||||
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth))
|
||||
# Recreate the whole new class body
|
||||
new_class_body = new_class_docstring + new_class_attributes + new_class_methods
|
||||
|
||||
# Replace the calls to `super()` of the redefined modular methods with the unrolled code
|
||||
result_node = original_modeling_node.with_changes(body=cst.IndentedBlock(body=new_class_body))
|
||||
temp_module = cst.Module(body=[result_node])
|
||||
new_module = MetadataWrapper(temp_module)
|
||||
new_replacement_class = new_module.visit(
|
||||
SuperTransformer(temp_module, original_methods, updated_methods, all_bases)
|
||||
SuperTransformer(temp_module, original_modeling_methods, modular_methods, all_bases)
|
||||
)
|
||||
new_replacement_body = new_replacement_class.body[0].body # get the indented block
|
||||
new_class_body = new_replacement_class.body[0].body # get the indented block
|
||||
|
||||
# Use decorators redefined in `modular_xxx.py` if any
|
||||
new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators
|
||||
|
||||
return original_node.with_changes(
|
||||
body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name
|
||||
return original_modeling_node.with_changes(
|
||||
body=new_class_body, decorators=new_class_decorators, bases=new_class_bases, name=new_class_name
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user