[refactor] set attention implementation (#38974)

* update

* fix some tests

* init from config, changes it in-place, add deepcopy in tests

* fix modernbert

* don't delete thsi config attr

* update

* style and copies

* skip tests in generation

* fix style

* accidentally removed flash-attn-3, revert

* docs

* forgot about flags set to False

* fix copies

* address a few comments

* fix copies

* custom code BC
This commit is contained in:
Raushan Turganbay
2025-07-15 12:34:06 +05:00
committed by GitHub
parent 6017f5e8ed
commit 8d6259b0b8
185 changed files with 451 additions and 776 deletions

View File

@@ -29,7 +29,15 @@ import pytest
from packaging import version
from parameterized import parameterized
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, is_torch_available, logging, pipeline
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
PreTrainedModel,
is_torch_available,
logging,
pipeline,
)
from transformers.testing_utils import (
CaptureLogger,
is_flaky,
@@ -2007,7 +2015,7 @@ class GenerationTesterMixin:
max_new_tokens = 20
for dtype in (torch.float32, torch.float16):
model = model_class(config).to(torch_device).to(dtype).eval()
model = model_class(copy.deepcopy(config)).to(torch_device).to(dtype).eval()
inputs_dict = {
k: v.to(dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v
for k, v in inputs_dict.items()
@@ -2340,6 +2348,18 @@ class GenerationTesterMixin:
set_config_for_less_flaky_test(config)
model = model_class(config)
# If not all sub-models support flex, skip the test. We could potentially set not supported backbones
# to "eager" attention, leaving it for future updates on multimodality tests
sub_models_supporting_attn = [
getattr(module, support_flag[attn_implementation])
for name, module in model.named_modules()
if isinstance(module, PreTrainedModel) and name != ""
]
if not all(sub_models_supporting_attn) and len(sub_models_supporting_attn) > 0:
self.skipTest(
f"One of {model_class.__name__}'s backbones does not support `attn_implementation={attn_implementation}`"
)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
del model