Modular: support for importing functions from any file (#35692)

* fix function imports

* improve comment

* Update modeling_switch_function.py

* make checks more robust

* improvement

* rename

* final test update
This commit is contained in:
Cyril Vallez
2025-01-16 16:37:53 +00:00
committed by GitHub
parent 8ebe9d7166
commit 91be6a5eb2
10 changed files with 305 additions and 43 deletions

View File

@@ -45,13 +45,8 @@ class SuperRMSNorm(nn.Module):
class SuperRotaryEmbedding(nn.Module):
def __init__(
self,
config: SuperConfig,
device=None,
):
def __init__(self, config: SuperConfig, device=None):
super().__init__()
self.rope_kwargs = {}
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
@@ -63,7 +58,7 @@ class SuperRotaryEmbedding(nn.Module):
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
@@ -75,13 +70,14 @@ class SuperRotaryEmbedding(nn.Module):
"""
seq_len = torch.max(position_ids) + 1
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
self.max_seq_len_cached = self.original_max_seq_len
@@ -356,6 +352,7 @@ class SuperPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True