🚨🚨 Fix and simplify attention implementation dispatch and subconfigs handling (#39423)
* first try * Update modeling_utils.py * Update modeling_utils.py * big refactor * Update modeling_utils.py * style * docstrings and simplify inner workings of configs * remove all trace of _internal * Update modeling_utils.py * fix logic error * Update modeling_utils.py * recursive on config * Update configuration_utils.py * fix * Update configuration_dpt.py * Update configuration_utils.py * Update configuration_utils.py * Update modeling_idefics.py * Update modeling_utils.py * fix for old models * more old models fixup * Update modeling_utils.py * Update configuration_utils.py * Remove outdated test * remove the deepcopy!! 🥵🥵 * Update test_modeling_gpt_bigcode.py * fix qwen dispatch * restrict to only models supporting it * style * switch name * Update modeling_utils.py * Update modeling_utils.py * add tests! * fix * rypo * remove bad copies * fix * Update modeling_utils.py * additional check * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * Update modeling_utils.py * fix * skip
This commit is contained in:
@@ -7,10 +7,3 @@ class CustomConfig(PretrainedConfig):
|
||||
def __init__(self, attribute=1, **kwargs):
|
||||
self.attribute = attribute
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class NoSuperInitConfig(PretrainedConfig):
|
||||
model_type = "custom"
|
||||
|
||||
def __init__(self, attribute=1, **kwargs):
|
||||
self.attribute = attribute
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch
|
||||
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from .custom_configuration import CustomConfig, NoSuperInitConfig
|
||||
from .custom_configuration import CustomConfig
|
||||
|
||||
|
||||
class CustomModel(PreTrainedModel):
|
||||
@@ -17,17 +17,3 @@ class CustomModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
pass
|
||||
|
||||
|
||||
class NoSuperInitModel(PreTrainedModel):
|
||||
config_class = NoSuperInitConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.linear = torch.nn.Linear(config.attribute, config.attribute)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
def _init_weights(self, module):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user