Clean-up composite configs (#34603)

* remove manual assignment of `tie_word_embeddings`

* remove another unused attribute

* fix tests

* fix tests

* remove unnecessary overwrites

* fix

* decoder=True

* clean pix2struct

* run-all

* add `_tied_weights_keys`, which was forgotten when adding Emu3

* also Aria + fix-copies

* and clean aria
This commit is contained in:
Raushan Turganbay
2025-01-15 10:04:07 +01:00
committed by GitHub
parent c61fcde910
commit 09d5f76274
33 changed files with 68 additions and 219 deletions

View File

@@ -1360,6 +1360,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]
def __init__(self, config: AriaConfig):
super().__init__(config)
@@ -1406,9 +1407,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def get_image_features(
self,
pixel_values: torch.FloatTensor,

View File

@@ -1337,6 +1337,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
config_class = AriaConfig
_supports_flash_attn_2 = False
_supports_sdpa = False
_tied_weights_keys = ["language_model.lm_head.weight"]
def __init__(self, config: AriaConfig):
super().__init__(config)
@@ -1383,9 +1384,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
return self.language_model.tie_weights()
def get_image_features(
self,
pixel_values: torch.FloatTensor,