Refactor the way we handle outputs for new llamas and new models (#39120)

* just update 2 files

* update other models as well just making fix-copies

* also add the changes needed to modeling utils

* put this on the pretrained model instead

* nits and fixes

* update generic, fix to use config value

* update other modelings

* use transformers kwargs instead

* update

* update

* update other models

* update

* updates

* update

* update

* update

* fix

* finally

* very small nits

* this fixes more tests

* fix other models as well!

* update modularqwen2

* update models based on qwen2

* update

* update

* remove the **flash stuff in favor of noraml kwargs

* update

* propagate gemma?

* remove output attentions

* propagate

* support cross attention edge case

* same

* test this

* fixes

* more fix

* update

* update

* fix conflicts

* update

* fix emu3

* fix emu3

* move the fix a bit

* quel enfer

* some fixes, loss_kwargs should never had been

* finish fixing gemma3n

* fix small lm3

* fix another one

* fix csm now

* fux csm and mistral

* fix mistral now

* small fixes

* fix janusss

* only for some models

* fixup

* phix phi3

* more fixes?

* dose this fix it?

* update

* holy shit it was just graph breaks

* protect torch

* updates

* fix samhq?

* fix moonshine

* more moonshine fixes, 3 failures left!

* nits

* generic needs to support more

* more fixes to moonshine!

* fix cross attention outputs!

* fix csm!

* nits

* fix stupid kosmos2

* current updates

* fixes

* use output recorder?

* nicer!

* a little bit of magic

* update

* fix protect

* fix

* small fixes

* protect import

* fix a bunch of more models

* fix fixups

* fix some of the last ones

* nit

* partly fix phi

* update

* fix import path

* make something that is fullgraph compatible just to be sure

* typing was wrong on llama so the rest was wrong as well

* fucking ugly but at least it is still exportable

* syle

* supposed to fix moonshine, it still breaks

* fix some default

* fix the last bits of sam

* update samhq

* more fixes to am hq

* nit

* fix all output+hidden states and output_attentions!

* fix?

* fix diffllama

* updates to fix initialization on the sam pips

* ups there was a bug

* fix the last sam hq test

* fix gotocr

* fix gotocr2!

* fixes

* skip stupid tests

* there was one left :)

* fixup

* fix fix copies issues with this test file

* fix copies for sam_hq

* rm some comments

* skip 2 more failing tests

* fix

* fix everything

* Apply suggestions from code review

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* add more doc!

* fix public init

* fix modular qwen3

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
Arthur
2025-07-05 11:34:28 +02:00
committed by GitHub
parent e6a8063ef1
commit ca7e1a3756
146 changed files with 2045 additions and 5936 deletions

View File

@@ -224,7 +224,6 @@ class T5GemmaModelTester:
lm_labels,
)
# Copied from tests.models.t5.test_modeling_t5.T5ModelTester.prepare_config_and_inputs_for_common
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -613,7 +612,6 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
num_hidden_layers=self.model_tester.num_hidden_layers,
)
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.is_pipeline_test_to_skip
def is_pipeline_test_to_skip(
self,
pipeline_test_case_name,
@@ -631,16 +629,14 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
return False
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_config
def test_config(self):
self.config_tester.run_common_tests()
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_shift_right
def test_shift_right(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_model
@unittest.skip("This was not properly written, submodules need the attribute to be overwritten")
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@@ -675,19 +671,17 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
with torch.no_grad():
model(**inputs)[0]
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_config_and_model_silu_gated
@unittest.skip("This was not properly written, submodules need the attribute to be overwritten")
def test_config_and_model_silu_gated(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config = config_and_inputs[0]
config.feed_forward_proj = "gated-silu"
self.model_tester.create_and_check_model(*config_and_inputs)
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_with_lm_head
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_with_sequence_classification_head
def test_with_sequence_classification_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs)
@@ -706,12 +700,11 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
*config_and_inputs, is_encoder_decoder
)
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_decoder_model_past
@unittest.skip("This was not properly written, submodules need the attribute to be overwritten")
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_decoder_model_past_with_attn_mask
def test_decoder_model_past_with_attn_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
@@ -745,18 +738,15 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
lm_labels,
)
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_decoder_model_past_with_large_inputs
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_generate_with_past_key_values
def test_generate_with_past_key_values(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Can't do half precision")
# Copied from tests.models.t5.test_modeling_t5.T5ModelTest.test_model_fp16_forward
def test_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
@@ -872,6 +862,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# Based on tests.test_modeling_common.ModelTesterMixin.test_attention_outputs
# Skip token classification
@unittest.skip("This was not properly written, submodules need the attribute to be overwritten")
def test_attention_outputs(self):
if not self.has_attentions:
self.skipTest(reason="Model does not output attentions")
@@ -909,7 +900,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
del inputs_dict["output_attentions"]
config._attn_implementation = "eager"
config.output_attentions = True
model = model_class(config)
model = model_class._from_config(config, attn_implementation="eager")
model.to(torch_device)
model.eval()
with torch.no_grad():
@@ -1254,6 +1245,7 @@ class T5GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
# Based on tests.test_modeling_common.ModelTesterMixin.test_inputs_embeds_matches_input_ids
# Adjust token classiifcation
@unittest.skip("This was not properly written, submodules need the attribute to be overwritten")
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
if model_class in [self.model_tester.for_token_class, self.model_tester.for_sequence_class]:
@@ -1607,6 +1599,7 @@ class T5GemmaEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip("This was not properly written, submodules need the attribute to be overwritten")
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)