Fix GPT2 with cross attention (#39754)
* fix
* use new mask API
* style
* fix copies and attention tests
* fix head pruning tests
This commit is contained in:
committed by
GitHub
parent
dfd616e658
commit
ccb2e0e03b
@@ -1775,6 +1775,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
+ model.set_attn_implementation("eager")
|
||||
heads_to_prune = {
|
||||
0: list(range(1, self.model_tester.num_attention_heads)),
|
||||
-1: [0],
|
||||
@@ -1808,6 +1809,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
+ model.set_attn_implementation("eager")
|
||||
heads_to_prune = {
|
||||
0: list(range(1, self.model_tester.num_attention_heads)),
|
||||
-1: [0],
|
||||
@@ -1816,7 +1818,7 @@ class ModelTesterMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
model.save_pretrained(temp_dir_name)
|
||||
- model = model_class.from_pretrained(temp_dir_name)
|
||||
+ model = model_class.from_pretrained(temp_dir_name, attn_implementation="eager")
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -1852,6 +1854,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
+ model.set_attn_implementation("eager")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
@@ -1884,6 +1887,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
+ model.set_attn_implementation("eager")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
@@ -1894,7 +1898,7 @@ class ModelTesterMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
model.save_pretrained(temp_dir_name)
|
||||
- model = model_class.from_pretrained(temp_dir_name)
|
||||
+ model = model_class.from_pretrained(temp_dir_name, attn_implementation="eager")
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user