Refactor Attention implementation for ViT-based models (#36545)
* Refactor ViT attention
* Refactor ViT-based models
* 🚨🚨🚨 Fix prefix for DPT
* Update params order
* Trigger tests
* Fix Dinov2 attention
* Fix DPT attention implementation propagation for backbone config
* Common test fix: config is modified in place - avoid it
* view -> reshape
* Fixup
* Fixup
* Enable IJepa FA2
* Add FA2 in corresponding model docs
This commit is contained in:
committed by
GitHub
parent
730d2a52e7
commit
66291778dd
@@ -315,8 +315,6 @@ class ModelTesterMixin:
|
||||
return inputs_dict
|
||||
|
||||
def test_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_save_load(out1, out2):
|
||||
# make sure we don't have nans
|
||||
out_2 = out2.cpu().numpy()
|
||||
@@ -330,6 +328,7 @@ class ModelTesterMixin:
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -508,16 +507,16 @@ class ModelTesterMixin:
|
||||
|
||||
@is_flaky(description="low likelihood of failure, reason not yet discovered")
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")
|
||||
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
base_class = base_class[0]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if config.__class__ not in MODEL_MAPPING:
|
||||
self.skipTest(reason=f"{config.__class__.__name__} not in MODEL_MAPPING")
|
||||
|
||||
base_class = MODEL_MAPPING[config.__class__]
|
||||
|
||||
if isinstance(base_class, tuple):
|
||||
base_class = base_class[0]
|
||||
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
@@ -2228,9 +2227,9 @@ class ModelTesterMixin:
|
||||
def test_correct_missing_keys(self):
|
||||
if not self.test_missing_keys:
|
||||
self.skipTest(reason="test_missing_keys is set to `False`")
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
base_model_prefix = model.base_model_prefix
|
||||
|
||||
@@ -2287,8 +2286,8 @@ class ModelTesterMixin:
|
||||
|
||||
@require_safetensors
|
||||
def test_can_use_safetensors(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model_tied = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
try:
|
||||
@@ -2323,9 +2322,9 @@ class ModelTesterMixin:
|
||||
)
|
||||
|
||||
def test_load_save_without_tied_weights(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.tie_word_embeddings = False
|
||||
for model_class in self.all_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.tie_word_embeddings = False
|
||||
model = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
model.save_pretrained(d)
|
||||
@@ -2373,8 +2372,8 @@ class ModelTesterMixin:
|
||||
)
|
||||
|
||||
def test_model_weights_reload_no_missing_tied_weights(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
Reference in New Issue
Block a user