Fix GPT2 with cross attention (#39754)

* fix

* use new mask API

* style

* fix copies and attention tests

* fix head pruning tests
This commit is contained in:
Raushan Turganbay
2025-07-29 15:40:31 +02:00
committed by GitHub
parent dfd616e658
commit ccb2e0e03b
4 changed files with 89 additions and 178 deletions

View File

@@ -1775,6 +1775,7 @@ class ModelTesterMixin:
model = model_class(config=config)
model.to(torch_device)
model.eval()
model.set_attn_implementation("eager")
heads_to_prune = {
0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0],
@@ -1808,6 +1809,7 @@ class ModelTesterMixin:
model = model_class(config=config)
model.to(torch_device)
model.eval()
model.set_attn_implementation("eager")
heads_to_prune = {
0: list(range(1, self.model_tester.num_attention_heads)),
-1: [0],
@@ -1816,7 +1818,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as temp_dir_name:
model.save_pretrained(temp_dir_name)
-model = model_class.from_pretrained(temp_dir_name)
+model = model_class.from_pretrained(temp_dir_name, attn_implementation="eager")
model.to(torch_device)
with torch.no_grad():
@@ -1852,6 +1854,7 @@ class ModelTesterMixin:
model = model_class(config=config)
model.to(torch_device)
model.eval()
model.set_attn_implementation("eager")
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
@@ -1884,6 +1887,7 @@ class ModelTesterMixin:
model = model_class(config=config)
model.to(torch_device)
model.eval()
model.set_attn_implementation("eager")
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
@@ -1894,7 +1898,7 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as temp_dir_name:
model.save_pretrained(temp_dir_name)
-model = model_class.from_pretrained(temp_dir_name)
+model = model_class.from_pretrained(temp_dir_name, attn_implementation="eager")
model.to(torch_device)
with torch.no_grad():