Refactor the way we handle outputs for new llamas and new models (#39120)

* just update 2 files
* update other models as well just making fix-copies
* also add the changes needed to modeling utils
* put this on the pretrained model instead
* nits and fixes
* update generic, fix to use config value
* update other modelings
* use transformers kwargs instead
* update
* update
* update other models
* update
* updates
* update
* update
* update
* fix
* finally
* very small nits
* this fixes more tests
* fix other models as well!
* update modularqwen2
* update models based on qwen2
* update
* update
* remove the **flash stuff in favor of noraml kwargs
* update
* propagate gemma?
* remove output attentions
* propagate
* support cross attention edge case
* same
* test this
* fixes
* more fix
* update
* update
* fix conflicts
* update
* fix emu3
* fix emu3
* move the fix a bit
* quel enfer
* some fixes, loss_kwargs should never had been
* finish fixing gemma3n
* fix small lm3
* fix another one
* fix csm now
* fux csm and mistral
* fix mistral now
* small fixes
* fix janusss
* only for some models
* fixup
* phix phi3
* more fixes?
* dose this fix it?
* update
* holy shit it was just graph breaks
* protect torch
* updates
* fix samhq?
* fix moonshine
* more moonshine fixes, 3 failures left!
* nits
* generic needs to support more
* more fixes to moonshine!
* fix cross attention outputs!
* fix csm!
* nits
* fix stupid kosmos2
* current updates
* fixes
* use output recorder?
* nicer!
* a little bit of magic
* update
* fix protect
* fix
* small fixes
* protect import
* fix a bunch of more models
* fix fixups
* fix some of the last ones
* nit
* partly fix phi
* update
* fix import path
* make something that is fullgraph compatible just to be sure
* typing was wrong on llama so the rest was wrong as well
* fucking ugly but at least it is still exportable
* syle
* supposed to fix moonshine, it still breaks
* fix some default
* fix the last bits of sam
* update samhq
* more fixes to am hq
* nit
* fix all output+hidden states and output_attentions!
* fix?
* fix diffllama
* updates to fix initialization on the sam pips
* ups there was a bug
* fix the last sam hq test
* fix gotocr
* fix gotocr2!
* fixes
* skip stupid tests
* there was one left :)
* fixup
* fix fix copies issues with this test file
* fix copies for sam_hq
* rm some comments
* skip 2 more failing tests
* fix
* fix everything
* Apply suggestions from code review

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* add more doc!
* fix public init
* fix modular qwen3

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

@@ -123,7 +123,7 @@ from .utils import (
     logging,
     strtobool,
 )
-from .utils.generic import GeneralInterface
+from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
 from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
     ENV_VARS_TRUE_VALUES,
@@ -1925,7 +1925,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
         - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
           models, `pixel_values` for vision models and `input_values` for speech models).
-    """
+        - **can_record_outputs** (dict):"""

     config_class = None
     base_model_prefix = ""
@@ -2006,6 +2006,50 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     # In practice, it means that they support attention interface functions, fully pass the kwargs
     # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
     _supports_attention_backend = False
+    _can_record_outputs = None
+
+    @property
+    @torch._dynamo.allow_in_graph
+    def can_record_outputs(self) -> dict[str, OutputRecorder]:
+        """
+        Maps output names (e.g., "attentions", "hidden_states")
+        to either:
+            - A module class (e.g., `LlamaDecoderLayer`), using default index conventions:
+                * index=0 for "hidden_states"
+                * index=1 for "attentions"
+            - Or an `OutputRecorder(...)` with `target_class`, optional `index`, and `layer_name`.
+
+        Examples:
+            These two are equivalent:
+
+        ```python
+            _can_record_outputs = {
+                "attentions": LlamaAttention,
+                "hidden_states": LlamaDecoderLayer
+            }
+
+            _can_record_outputs = {
+                "attentions": OutputRecorder(LlamaAttention, index=1),
+                "hidden_states": OutputRecorder(LlamaDecoderLayer, index=0)
+            }
+        ```
+
+        This means you can record several outputs from the same class by specifying a layer name. Before
+        collecting outputs, we check that they come from this layer.
+
+        If you have cross attention that comes from `LlamaAttention` under `cross_attn` and self attention
+        that also comes from `LlamaAttention` but under `self_attn`, you can do this:
+
+        ```python
+        class LlamaModel(PreTrainedModel):
+            _can_record_outputs = {
+                "attentions": OutputRecorder(LlamaAttention, index=1, layer_name="self_attn"),
+                "cross_attentions": OutputRecorder(LlamaAttention, index=1, layer_name="cross_attn")
+            }
+
+        ```
+        """
+        return self._can_record_outputs or {}

     @property
     def dummy_inputs(self) -> dict[str, torch.Tensor]:
@@ -2056,6 +2100,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)

         self._no_split_modules = self._no_split_modules or []
+        _CAN_RECORD_REGISTRY[self] = self._can_record_outputs  # added for executorch support only

     def post_init(self):
         """
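
The default index convention described in the `can_record_outputs` docstring can be illustrated with a small standalone sketch. Everything here is hypothetical and written only for this example: `RecorderSpec`, `DEFAULT_INDEX`, `normalize_spec`, `matches`, and the toy classes are not the library's `OutputRecorder` or its internals. Only the field names (`target_class`, `index`, `layer_name`) and the defaults (0 for "hidden_states", 1 for "attentions") come from the docstring. The matching rule in `matches`, a substring check on the module's qualified name, is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class RecorderSpec:
    """Hypothetical stand-in for OutputRecorder (field names taken from the docstring)."""

    target_class: type
    index: int = 0
    layer_name: Optional[str] = None


# Default tuple index per output name, as stated in the docstring.
DEFAULT_INDEX = {"hidden_states": 0, "attentions": 1}


def normalize_spec(spec: dict[str, Union[type, RecorderSpec]]) -> dict[str, RecorderSpec]:
    """Turn bare classes into RecorderSpec objects using the default index."""
    normalized = {}
    for name, value in spec.items():
        if isinstance(value, RecorderSpec):
            normalized[name] = value
        else:
            # Names without a documented default fall back to 0 (an assumption).
            normalized[name] = RecorderSpec(value, index=DEFAULT_INDEX.get(name, 0))
    return normalized


def matches(recorder: RecorderSpec, module: object, qualified_name: str) -> bool:
    """Assumed rule: the class must match and layer_name, if set, must appear in the name."""
    if not isinstance(module, recorder.target_class):
        return False
    return recorder.layer_name is None or recorder.layer_name in qualified_name


# Toy classes standing in for LlamaAttention and LlamaDecoderLayer.
class ToyAttention: ...
class ToyDecoderLayer: ...


short_form = normalize_spec({"attentions": ToyAttention, "hidden_states": ToyDecoderLayer})
long_form = normalize_spec({
    "attentions": RecorderSpec(ToyAttention, index=1),
    "hidden_states": RecorderSpec(ToyDecoderLayer, index=0),
})
same = short_form == long_form

# A recorder restricted to self-attention ignores the cross-attention module.
self_attn = RecorderSpec(ToyAttention, index=1, layer_name="self_attn")
hit = matches(self_attn, ToyAttention(), "layers.0.self_attn")
miss = matches(self_attn, ToyAttention(), "layers.0.cross_attn")
```

In this sketch, `same` confirms that the short and long forms normalize to identical specs, while `hit` and `miss` show how a layer name can separate two uses of one attention class.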