Clean-up composite configs (#34603)
* remove manual assignment tie-word-embeddings * remove another unused attribute * fix tests * fix tests * remove unnecessary overwrites * fix * decoder=True * clean pix2struct * run-all * forgot `_tied_weights_keys` when adding Emu3 * also Aria + fix-copies * and clean aria
This commit is contained in:
committed by
GitHub
parent
c61fcde910
commit
09d5f76274
@@ -1360,6 +1360,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
config_class = AriaConfig
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_sdpa = False
|
||||
_tied_weights_keys = ["language_model.lm_head.weight"]
|
||||
|
||||
def __init__(self, config: AriaConfig):
|
||||
super().__init__(config)
|
||||
@@ -1406,9 +1407,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def tie_weights(self):
|
||||
return self.language_model.tie_weights()
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
|
||||
@@ -1337,6 +1337,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
config_class = AriaConfig
|
||||
_supports_flash_attn_2 = False
|
||||
_supports_sdpa = False
|
||||
_tied_weights_keys = ["language_model.lm_head.weight"]
|
||||
|
||||
def __init__(self, config: AriaConfig):
|
||||
super().__init__(config)
|
||||
@@ -1383,9 +1384,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
def get_decoder(self):
|
||||
return self.language_model.get_decoder()
|
||||
|
||||
def tie_weights(self):
|
||||
return self.language_model.tie_weights()
|
||||
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
|
||||
Reference in New Issue
Block a user