modular_model_converter bugfix on assignments (#35642)
* added bugfix in modular converter to keep modular assignments for docstrings, expected outputs etc. * revert starcoder2 docstring copying, add forward in EMU3 to enable docstring assignment, remove verbatim assignments in modular converter * added _FOR_DOC in assignments to keep, corrected wrong checkpoint name in ijepa's configuration
This commit is contained in:
@@ -61,6 +61,7 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "BambaConfig"
|
_CONFIG_FOR_DOC = "BambaConfig"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ from .configuration_cohere import CohereConfig
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "CohereConfig"
|
_CONFIG_FOR_DOC = "CohereConfig"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from .configuration_cohere2 import Cohere2Config
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "Cohere2Config"
|
_CONFIG_FOR_DOC = "Cohere2Config"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1257,7 +1257,7 @@ class Emu3RotaryEmbedding(nn.Module):
|
|||||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
EMU3_INPUTS_DOCSTRING = r"""
|
EMU3_TEXT_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
@@ -1292,19 +1292,15 @@ EMU3_INPUTS_DOCSTRING = r"""
|
|||||||
config.n_positions - 1]`.
|
config.n_positions - 1]`.
|
||||||
|
|
||||||
[What are position IDs?](../glossary#position-ids)
|
[What are position IDs?](../glossary#position-ids)
|
||||||
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
past_key_values (`Cache`, *optional*):
|
||||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
Two formats are allowed:
|
Has to be an instance of [`~cache_utils.Cache`] instance, see our
|
||||||
- a [`~cache_utils.Cache`] instance, see our
|
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||||
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
|
||||||
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
|
||||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
|
||||||
cache format.
|
|
||||||
|
|
||||||
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
The model will output the same cache type that is fed as input. If no `past_key_values` are passed, the
|
||||||
legacy cache format will be returned.
|
legacy cache format will be returned.
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
@@ -1366,7 +1362,7 @@ class Emu3TextModel(Emu3PreTrainedModel):
|
|||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, value):
|
||||||
self.embed_tokens = value
|
self.embed_tokens = value
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
@@ -1598,77 +1594,6 @@ class Emu3TextModel(Emu3PreTrainedModel):
|
|||||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
|
||||||
|
|
||||||
|
|
||||||
EMU3_TEXT_INPUTS_DOCSTRING = r"""
|
|
||||||
Args:
|
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
|
||||||
it.
|
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
|
||||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
|
||||||
|
|
||||||
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
|
||||||
`past_key_values`).
|
|
||||||
|
|
||||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
|
||||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
|
||||||
information on the default strategy.
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
|
||||||
- 0 indicates the head is **masked**.
|
|
||||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
|
||||||
config.n_positions - 1]`.
|
|
||||||
|
|
||||||
[What are position IDs?](../glossary#position-ids)
|
|
||||||
past_key_values (`Cache`, *optional*):
|
|
||||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
|
||||||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
|
||||||
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
|
||||||
|
|
||||||
Has to be an instance of [`~cache_utils.Cache`] instance, see our
|
|
||||||
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
|
||||||
|
|
||||||
The model will output the same cache type that is fed as input. If no `past_key_values` are passed, the
|
|
||||||
legacy cache format will be returned.
|
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
|
||||||
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
|
||||||
of shape `(batch_size, sequence_length)`.
|
|
||||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
|
||||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
|
||||||
model's internal embedding lookup matrix.
|
|
||||||
use_cache (`bool`, *optional*):
|
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
|
||||||
`past_key_values`).
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
||||||
tensors for more detail.
|
|
||||||
output_hidden_states (`bool`, *optional*):
|
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
||||||
more detail.
|
|
||||||
return_dict (`bool`, *optional*):
|
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
||||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
|
||||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
|
||||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
|
||||||
the complete sequence length.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
|
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
_tp_plan = {"lm_head": "colwise_rep"}
|
_tp_plan = {"lm_head": "colwise_rep"}
|
||||||
@@ -1790,6 +1715,85 @@ class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
EMU3_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||||||
|
it.
|
||||||
|
|
||||||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
|
[What are input IDs?](../glossary#input-ids)
|
||||||
|
pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)):
|
||||||
|
The tensors corresponding to the input images. Pixel values can be obtained using
|
||||||
|
[`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses
|
||||||
|
[`Emu3ImageProcessor`] for processing images).
|
||||||
|
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
|
||||||
|
The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
|
||||||
|
[`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses
|
||||||
|
[`Emu3ImageProcessor`] for processing images).
|
||||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
|
||||||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
|
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
|
||||||
|
`past_key_values`).
|
||||||
|
|
||||||
|
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
||||||
|
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
||||||
|
information on the default strategy.
|
||||||
|
|
||||||
|
- 1 indicates the head is **not masked**,
|
||||||
|
- 0 indicates the head is **masked**.
|
||||||
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
||||||
|
config.n_positions - 1]`.
|
||||||
|
|
||||||
|
[What are position IDs?](../glossary#position-ids)
|
||||||
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
|
Has to be an instance of [`~cache_utils.Cache`] instance, see our
|
||||||
|
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||||
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||||||
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||||||
|
model's internal embedding lookup matrix.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
|
`past_key_values`).
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||||
|
tensors for more detail.
|
||||||
|
output_hidden_states (`bool`, *optional*):
|
||||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
|
more detail.
|
||||||
|
return_dict (`bool`, *optional*):
|
||||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||||
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||||
|
the complete sequence length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["text_model.lm_head.weight"]
|
_tied_weights_keys = ["text_model.lm_head.weight"]
|
||||||
|
|
||||||
|
|||||||
@@ -1059,6 +1059,10 @@ class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
|
|||||||
[Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
[Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING)
|
||||||
|
def forward(self, **super_kwargs):
|
||||||
|
super().forward(**super_kwargs)
|
||||||
|
|
||||||
|
|
||||||
class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
|
class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
|
||||||
config_class = Emu3TextConfig
|
config_class = Emu3TextConfig
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class IJepaConfig(PretrainedConfig):
|
|||||||
This is the configuration class to store the configuration of a [`IJepaModel`]. It is used to instantiate an IJEPA
|
This is the configuration class to store the configuration of a [`IJepaModel`]. It is used to instantiate an IJEPA
|
||||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
defaults will yield a similar configuration to that of the I-JEPA
|
defaults will yield a similar configuration to that of the I-JEPA
|
||||||
[google/ijepa-base-patch16-224](https://huggingface.co/google/ijepa-base-patch16-224) architecture.
|
[facebook/ijepa_vith14_1k](https://huggingface.co/facebook/ijepa_vith14_1k) architecture.
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|||||||
@@ -527,7 +527,9 @@ IJEPA_INPUTS_DOCSTRING = r"""
|
|||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
"""
|
"""
|
||||||
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
|
|
||||||
|
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 256, 1280]
|
||||||
|
|
||||||
|
|
||||||
IJEPA_START_DOCSTRING = r"""
|
IJEPA_START_DOCSTRING = r"""
|
||||||
@@ -640,8 +642,7 @@ class IJepaModel(IJepaPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Image classification docstring
|
_IMAGE_CLASS_CHECKPOINT = "facebook/ijepa_vith14_1k"
|
||||||
_IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224"
|
|
||||||
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
|
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ from .configuration_moonshine import MoonshineConfig
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "MoonshineConfig"
|
_CONFIG_FOR_DOC = "MoonshineConfig"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from .configuration_phi import PhiConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CHECKPOINT_FOR_DOC = "meta-phi/Phi-2-7b-hf"
|
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
|
||||||
_CONFIG_FOR_DOC = "PhiConfig"
|
_CONFIG_FOR_DOC = "PhiConfig"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ from .configuration_phi import PhiConfig
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
_CHECKPOINT_FOR_DOC = "microsoft/phi-1"
|
||||||
|
_CONFIG_FOR_DOC = "PhiConfig"
|
||||||
|
|
||||||
|
|
||||||
class PhiAttention(LlamaAttention):
|
class PhiAttention(LlamaAttention):
|
||||||
def __init__(self, config: PhiConfig, layer_idx: int):
|
def __init__(self, config: PhiConfig, layer_idx: int):
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ from .configuration_starcoder2 import Starcoder2Config
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b"
|
_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b"
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "Starcoder2Config"
|
_CONFIG_FOR_DOC = "Starcoder2Config"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -268,7 +268,7 @@ def merge_docstrings(original_docstring, updated_docstring):
|
|||||||
updated_docstring = "".join(
|
updated_docstring = "".join(
|
||||||
[
|
[
|
||||||
parts[0].rstrip(" \n") + new_parts[0],
|
parts[0].rstrip(" \n") + new_parts[0],
|
||||||
f"\n{original_level*' '}```",
|
f"\n{original_level * ' '}```",
|
||||||
parts[1],
|
parts[1],
|
||||||
"```",
|
"```",
|
||||||
parts[2],
|
parts[2],
|
||||||
@@ -515,10 +515,8 @@ def find_all_dependencies(
|
|||||||
return all_dependencies_with_parent
|
return all_dependencies_with_parent
|
||||||
|
|
||||||
|
|
||||||
# These top-level variables will always use the value in the `modular_xxx.py` file
|
# Top-level variables that match the following patterns will always use the value in the `modular_xxx.py` file
|
||||||
ASSIGNMENTS_TO_KEEP = {
|
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC"]
|
||||||
"_CHECKPOINT_FOR_DOC",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ClassDependencyMapper(CSTVisitor):
|
class ClassDependencyMapper(CSTVisitor):
|
||||||
@@ -828,12 +826,14 @@ class ModelFileMapper(ModuleMapper):
|
|||||||
def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
|
def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]):
|
||||||
"""Update the global nodes with the assignment from the modular file.
|
"""Update the global nodes with the assignment from the modular file.
|
||||||
|
|
||||||
Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it is
|
Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it matches
|
||||||
in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the
|
a pattern in `ASSIGNMENTS_REGEX_TO_KEEP`. Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the
|
||||||
big docstrings.
|
big docstrings.
|
||||||
"""
|
"""
|
||||||
for assignment, node in assignments.items():
|
for assignment, node in assignments.items():
|
||||||
if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments:
|
should_keep = any(re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP)
|
||||||
|
|
||||||
|
if should_keep or assignment not in self.assignments:
|
||||||
self.assignments[assignment] = node
|
self.assignments[assignment] = node
|
||||||
if assignment in object_mapping:
|
if assignment in object_mapping:
|
||||||
self.object_dependency_mapping[assignment] = object_mapping[assignment]
|
self.object_dependency_mapping[assignment] = object_mapping[assignment]
|
||||||
@@ -1404,7 +1404,7 @@ class ModularFileMapper(ModuleMapper):
|
|||||||
]
|
]
|
||||||
if len(modeling_bases) > 1:
|
if len(modeling_bases) > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*modeling_bases,}."
|
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {(*modeling_bases,)}."
|
||||||
)
|
)
|
||||||
if len(modeling_bases) == 1:
|
if len(modeling_bases) == 1:
|
||||||
filename = self.model_specific_imported_objects[modeling_bases[0]]
|
filename = self.model_specific_imported_objects[modeling_bases[0]]
|
||||||
@@ -1432,7 +1432,7 @@ class ModularFileMapper(ModuleMapper):
|
|||||||
if final_name != cased_default_name and has_prefix_collision:
|
if final_name != cased_default_name and has_prefix_collision:
|
||||||
if len(prefixes_counter) > 1:
|
if len(prefixes_counter) > 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. However, the "
|
f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. However, the "
|
||||||
f"most used one, '{final_name}', is already present in the source file and will likely cause consistency "
|
f"most used one, '{final_name}', is already present in the source file and will likely cause consistency "
|
||||||
f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args "
|
f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args "
|
||||||
"and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different "
|
"and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different "
|
||||||
@@ -1448,7 +1448,7 @@ class ModularFileMapper(ModuleMapper):
|
|||||||
final_name = cased_default_name
|
final_name = cased_default_name
|
||||||
elif len(prefixes_counter) > 1:
|
elif len(prefixes_counter) > 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"We detected multiple prefix names when inheriting from {file}: {*set(prefixes_counter),}. We will only "
|
f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. We will only "
|
||||||
f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the "
|
f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the "
|
||||||
f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix "
|
f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix "
|
||||||
"in all the modular (best)."
|
"in all the modular (best)."
|
||||||
|
|||||||
Reference in New Issue
Block a user