[refactor] set attention implementation (#38974)
* update * fix some tests * init from config, changes it in-place, add deepcopy in tests * fix modernbert * don't delete this config attr * update * style and copies * skip tests in generation * fix style * accidentally removed flash-attn-3, revert * docs * forgot about flags set to False * fix copies * address a few comments * fix copies * custom code BC
This commit is contained in:
committed by
GitHub
parent
6017f5e8ed
commit
8d6259b0b8
@@ -29,7 +29,15 @@ import pytest
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, is_torch_available, logging, pipeline
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
is_torch_available,
|
||||
logging,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_flaky,
|
||||
@@ -2007,7 +2015,7 @@ class GenerationTesterMixin:
|
||||
max_new_tokens = 20
|
||||
|
||||
for dtype in (torch.float32, torch.float16):
|
||||
model = model_class(config).to(torch_device).to(dtype).eval()
|
||||
model = model_class(copy.deepcopy(config)).to(torch_device).to(dtype).eval()
|
||||
inputs_dict = {
|
||||
k: v.to(dtype) if isinstance(v, torch.Tensor) and torch.is_floating_point(v) else v
|
||||
for k, v in inputs_dict.items()
|
||||
@@ -2340,6 +2348,18 @@ class GenerationTesterMixin:
|
||||
set_config_for_less_flaky_test(config)
|
||||
model = model_class(config)
|
||||
|
||||
# If not all sub-models support the requested attention implementation, skip the test. We could potentially set unsupported backbones
|
||||
# to "eager" attention, leaving it for future updates on multimodality tests
|
||||
sub_models_supporting_attn = [
|
||||
getattr(module, support_flag[attn_implementation])
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, PreTrainedModel) and name != ""
|
||||
]
|
||||
if not all(sub_models_supporting_attn) and len(sub_models_supporting_attn) > 0:
|
||||
self.skipTest(
|
||||
f"One of {model_class.__name__}'s backbones does not support `attn_implementation={attn_implementation}`"
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
del model
|
||||
|
||||
Reference in New Issue
Block a user