[modular] Do not track imports in functions (#36279)

* Add check

* just check for function

* Update examples
This commit is contained in:
Cyril Vallez
2025-02-25 10:29:47 +01:00
committed by GitHub
parent 4b5cf5496d
commit bc65f3fc1c
10 changed files with 82 additions and 33 deletions

View File

@@ -356,6 +356,7 @@ class DummyPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
def _init_weights(self, module):
std = self.config.initializer_range
@@ -698,7 +699,9 @@ class DummyModel(DummyPreTrainedModel):
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
causal_mask.device
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype