Refactor the way we handle outputs for new llamas and new models (#39120)

* just update 2 files

* update other models as well just making fix-copies

* also add the changes needed to modeling utils

* put this on the pretrained model instead

* nits and fixes

* update generic, fix to use config value

* update other modelings

* use transformers kwargs instead

* update

* update

* update other models

* update

* updates

* update

* update

* update

* fix

* finally

* very small nits

* this fixes more tests

* fix other models as well!

* update modularqwen2

* update models based on qwen2

* update

* update

* remove the **flash stuff in favor of normal kwargs

* update

* propagate gemma?

* remove output attentions

* propagate

* support cross attention edge case

* same

* test this

* fixes

* more fix

* update

* update

* fix conflicts

* update

* fix emu3

* fix emu3

* move the fix a bit

* quel enfer

* some fixes, loss_kwargs should never have been

* finish fixing gemma3n

* fix small lm3

* fix another one

* fix csm now

* fix csm and mistral

* fix mistral now

* small fixes

* fix janusss

* only for some models

* fixup

* phix phi3

* more fixes?

* does this fix it?

* update

* holy shit it was just graph breaks

* protect torch

* updates

* fix samhq?

* fix moonshine

* more moonshine fixes, 3 failures left!

* nits

* generic needs to support more

* more fixes to moonshine!

* fix cross attention outputs!

* fix csm!

* nits

* fix stupid kosmos2

* current updates

* fixes

* use output recorder?

* nicer!

* a little bit of magic

* update

* fix protect

* fix

* small fixes

* protect import

* fix a bunch of more models

* fix fixups

* fix some of the last ones

* nit

* partly fix phi

* update

* fix import path

* make something that is fullgraph compatible just to be sure

* typing was wrong on llama so the rest was wrong as well

* fucking ugly but at least it is still exportable

* style

* supposed to fix moonshine, it still breaks

* fix some default

* fix the last bits of sam

* update samhq

* more fixes to sam hq

* nit

* fix all output+hidden states and output_attentions!

* fix?

* fix diffllama

* updates to fix initialization on the sam pips

* ups there was a bug

* fix the last sam hq test

* fix gotocr

* fix gotocr2!

* fixes

* skip stupid tests

* there was one left :)

* fixup

* fix fix copies issues with this test file

* fix copies for sam_hq

* rm some comments

* skip 2 more failing tests

* fix

* fix everything

* Apply suggestions from code review

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* add more doc!

* fix public init

* fix modular qwen3

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
This commit is contained in:
Arthur
2025-07-05 11:34:28 +02:00
committed by GitHub
parent e6a8063ef1
commit ca7e1a3756
146 changed files with 2045 additions and 5936 deletions

View File

@@ -123,7 +123,7 @@ from .utils import (
logging,
strtobool,
)
from .utils.generic import GeneralInterface
from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
@@ -1925,7 +1925,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
- **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
"""
- **can_record_outputs** (dict):"""
config_class = None
base_model_prefix = ""
@@ -2006,6 +2006,50 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# In practice, it means that they support attention interface functions, fully pass the kwargs
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
_supports_attention_backend = False
_can_record_outputs = None
@property
@torch._dynamo.allow_in_graph
def can_record_outputs(self) -> dict[str, OutputRecorder]:
"""
Return `self._can_record_outputs`, or an empty dict when it is `None` or empty.

The mapping goes from output names (e.g., "attentions", "hidden_states")
to either:
- A module class (e.g., `LlamaDecoderLayer`), using default index conventions:
* index=0 for "hidden_states"
* index=1 for "attentions"
- Or an `OutputRecorder(...)` with `target_class`, optional `index`, and `layer_name`.

Examples:
These two are equivalent:
```python
_can_record_outputs = {
"attentions": LlamaAttention,
"hidden_states": LlamaDecoderLayer
}
_can_record_outputs = {
"attentions": OutputRecorder(LlamaAttention, index=1),
"hidden_states": OutputRecorder(LlamaDecoderLayer, index=0)
}
```

An `OutputRecorder` can also carry a `layer_name`, which lets you record several
outputs that come from the same class. Before collecting outputs, we check that
they come from this layer.
If you have cross attention that comes from `LlamaAttention` and self attention that also
comes from `LlamaAttention` but from `self_attn`, you can do this:
```python
class LlamaModel(PreTrainedModel):
_can_record_outputs = {
"attentions": OutputRecorder(LlamaAttention, index=1, layer_name="self_attn"),
"cross_attentions": OutputRecorder(LlamaAttention, index=1, layer_name="cross_attn")
}
```

Returns:
`dict`: the mapping stored in `_can_record_outputs`; a new empty dict if that
attribute is falsy.
"""
# A falsy value (None or an empty mapping) falls back to a fresh empty dict.
return self._can_record_outputs or {}
@property
def dummy_inputs(self) -> dict[str, torch.Tensor]:
@@ -2056,6 +2100,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
self._no_split_modules = self._no_split_modules or []
_CAN_RECORD_REGISTRY[self] = self._can_record_outputs # added for executorch support only
def post_init(self):
"""