Fix AttentionInterface following feedback (#37010)

* up

* typo

* update doc

* Update attention_interface.md
This commit is contained in:
Cyril Vallez
2025-03-28 18:00:35 +01:00
committed by GitHub
parent a86dad56bc
commit 2bea6bf24e
2 changed files with 33 additions and 11 deletions

View File

@@ -5917,7 +5917,7 @@ class AttentionInterface(MutableMapping):
"""
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
it needs to declare a new instance of this class inside the `modeling.py`, and declare it on that instance.
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
"""
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
@@ -5946,7 +5946,7 @@ class AttentionInterface(MutableMapping):
def __iter__(self):
# Ensure we use all keys, with the overwritten ones on top
return iter(self._global_mapping.update(self._local_mapping))
return iter({**self._global_mapping, **self._local_mapping})
def __len__(self):
return len(self._global_mapping.keys() | self._local_mapping.keys())
@@ -5956,7 +5956,7 @@ class AttentionInterface(MutableMapping):
cls._global_mapping.update({key: value})
def valid_keys(self) -> List[str]:
return list(self._global_mapping.keys() | self._local_mapping.keys())
return list(self.keys())
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones