Fix AttentionInterface following feedback (#37010)
* up * typo * update doc * Update attention_interface.md
This commit is contained in:
@@ -5917,7 +5917,7 @@ class AttentionInterface(MutableMapping):
|
||||
"""
|
||||
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
|
||||
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
|
||||
it needs to declare a new instance of this class inside the `modeling.py`, and declare it on that instance.
|
||||
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
|
||||
"""
|
||||
|
||||
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
|
||||
@@ -5946,7 +5946,7 @@ class AttentionInterface(MutableMapping):
|
||||
|
||||
def __iter__(self):
|
||||
# Ensure we use all keys, with the overwritten ones on top
|
||||
return iter(self._global_mapping.update(self._local_mapping))
|
||||
return iter({**self._global_mapping, **self._local_mapping})
|
||||
|
||||
def __len__(self):
|
||||
return len(self._global_mapping.keys() | self._local_mapping.keys())
|
||||
@@ -5956,7 +5956,7 @@ class AttentionInterface(MutableMapping):
|
||||
cls._global_mapping.update({key: value})
|
||||
|
||||
def valid_keys(self) -> List[str]:
|
||||
return list(self._global_mapping.keys() | self._local_mapping.keys())
|
||||
return list(self.keys())
|
||||
|
||||
|
||||
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
|
||||
|
||||
Reference in New Issue
Block a user