🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models (#38108)
* starting attn refactor for encoder decoder models via bart (eager + sdpa) * flash attention works, remove unnecessary code * flex attention support for bart!, gotta check if the renaming is not too aggressive * some comments * skip flex grad test for standalone as done with the other test * revert flex attn rename (for now), sdpa simplify, and todos * more todos * refactor mask creation for reuse * modular attempt at biogpt * first batch of other models * fix attn dropout * fix autoformer copies * hubert * another batch of models * copies/style + last round of bart models --> whisper next? * remove unnecessary _reshape function and remove copy to whisper * add skip for decoder-only models out of enc-dec (same as in bart) * bring back licences * remove comment, added to pr read instead * mostly docs * disable sew flex attn as it's unclear attn mask for now * oops * test fixes for enc-dec * torch fx fixes + try at flex attn * skip on mbart * some more fixes * musicgen skip / delete old attn class logic + sdpa compose compile skip * disable flex attn for musicgen, not worth the effort * more fixes and style * flex attention test for dropout and encoder decoder that don't have main input names * informer fixes * the weirdest thing I've encountered yet... * style * remove empty tensor attempt, found root cause in previous commits * disable time series due to tests being very text centric on inputs * add speech to text to be ignoring the other attns, also due to tests * update docs * remaining issues resolved ? * update docs for current state --> nllb moe and pegasus x sdpa is questionable :D * some models have not set the is_causal flag... 
* change dtype in softmax to old behaviour + some modular fixes * I hate it but it is what it is * fixes from main for bart * forgot this one * some model fixes * style * current status * marian works now * fixing some copies * some copy fixes + time series x informer * last models possibly and fixes on style/copies * some post merge fixes * more fixes * make attention interface callable and move warnings there * style lol * add comment to "unsupported" * remove callable interface and change interface warnings + some copies * fix * ternary is ugly af, make it simpler * how did that happen * fix flex attn test * failing the test * no more fallback! fixing copies next * style + attn fixed * fixing copies and mask creation * wrong copy * fixup tests and disable flex attn for now * fixup last tests?
This commit is contained in:
@@ -1521,3 +1521,7 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_flex_attention_with_grads():
|
||||
return
|
||||
|
||||
@@ -554,3 +554,7 @@ class BlenderbotStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
|
||||
@unittest.skip(reason="decoder cannot keep gradients")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_flex_attention_with_grads():
|
||||
return
|
||||
|
||||
@@ -563,3 +563,7 @@ class BlenderbotSmallStandaloneDecoderModelTest(ModelTesterMixin, GenerationTest
|
||||
@unittest.skip(reason="decoder cannot keep gradients")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_flex_attention_with_grads():
|
||||
return
|
||||
|
||||
@@ -420,6 +420,9 @@ class Data2VecAudioModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Tes
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
|
||||
@@ -412,6 +412,10 @@ class EncoderDecoderMixin:
|
||||
labels,
|
||||
**kwargs,
|
||||
):
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
decoder_config._attn_implementation = "eager"
|
||||
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
@@ -445,6 +449,10 @@ class EncoderDecoderMixin:
|
||||
# config file. Contrary to most models, changing the model's config won't work -- the defaults are loaded
|
||||
# from the inner models' configurations.
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
decoder_config._attn_implementation = "eager"
|
||||
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
|
||||
@@ -370,6 +370,9 @@ class HubertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
@@ -632,6 +635,9 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
|
||||
@@ -357,7 +357,8 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
# change to expected output here
|
||||
expected_slice = torch.tensor(
|
||||
[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]], device=torch_device
|
||||
[[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]]],
|
||||
device=torch_device,
|
||||
)
|
||||
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
|
||||
|
||||
@@ -374,7 +375,8 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
# change to expected output here
|
||||
expected_slice = torch.tensor(
|
||||
[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]], device=torch_device
|
||||
[[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]]],
|
||||
device=torch_device,
|
||||
)
|
||||
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
|
||||
|
||||
@@ -426,7 +428,7 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
|
||||
Overwriting the common test as the test is flaky on tiny models
|
||||
"""
|
||||
model = M2M100ForConditionalGeneration.from_pretrained(
|
||||
"facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
|
||||
"facebook/m2m100_418M", attn_implementation="flash_attention_2", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")
|
||||
|
||||
@@ -850,3 +850,7 @@ class MarianStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_flex_attention_with_grads():
|
||||
return
|
||||
|
||||
@@ -735,3 +735,7 @@ class MBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, u
|
||||
@unittest.skip(reason="Decoder cannot retain gradients")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
|
||||
@unittest.skip(reason="Decoder cannot retain gradients")
|
||||
def test_flex_attention_with_grads(self):
|
||||
return
|
||||
|
||||
@@ -728,6 +728,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
self.check_musicgen_model_output_attentions(model_class, config, **inputs_dict)
|
||||
self.check_musicgen_model_output_attentions_from_config(model_class, config, **inputs_dict)
|
||||
@@ -805,6 +808,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
config.text_encoder.output_attentions = True
|
||||
config.decoder.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
@@ -1036,30 +1042,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_conversion(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
|
||||
).to(torch_device)
|
||||
|
||||
for _, module in model.named_modules():
|
||||
if "FlashAttention" in module.__class__.__name__:
|
||||
return
|
||||
|
||||
self.assertTrue(False, "FlashAttention2 modules not found in model")
|
||||
self.skipTest(reason="Musicgen doesn't use the MusicgenFlashAttention2 class method.")
|
||||
|
||||
@require_torch_sdpa
|
||||
@require_torch_gpu
|
||||
@@ -1234,18 +1217,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
if "SdpaAttention" in submodule.__class__.__name__:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
if "SdpaAttention" in submodule.__class__.__name__:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa and model_sdpa.config.model_type != "falcon":
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
def test_requires_grad_with_frozen_encoders(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -1276,6 +1247,10 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
def test_generation_tester_mixin_inheritance(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5."))
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
|
||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||
"""Produces a series of 'bip bip' sounds at a given frequency."""
|
||||
|
||||
@@ -731,6 +731,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
self.check_musicgen_melody_model_output_attentions(model_class, config, **inputs_dict)
|
||||
self.check_musicgen_melody_model_output_attentions_from_config(model_class, config, **inputs_dict)
|
||||
@@ -807,6 +810,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
config.text_encoder.output_attentions = True
|
||||
config.decoder.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
@@ -1036,30 +1042,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_conversion(self):
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
|
||||
).to(torch_device)
|
||||
|
||||
for _, module in model.named_modules():
|
||||
if "FlashAttention" in module.__class__.__name__:
|
||||
return
|
||||
|
||||
self.assertTrue(False, "FlashAttention2 modules not found in model")
|
||||
self.skipTest(reason="MusicgenMelody doesn't use the MusicgenMelodyFlashAttention2 class method.")
|
||||
|
||||
@require_torch_sdpa
|
||||
@require_torch_gpu
|
||||
@@ -1234,18 +1217,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
if "SdpaAttention" in submodule.__class__.__name__:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
if "SdpaAttention" in submodule.__class__.__name__:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa and model_sdpa.config.model_type != "falcon":
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
def test_requires_grad_with_frozen_encoders(self):
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -1276,6 +1247,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
def test_generation_tester_mixin_inheritance(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5."))
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
|
||||
# Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip
|
||||
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
|
||||
|
||||
@@ -480,7 +480,7 @@ class PatchTSMixerModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([[[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],]],device=torch_device) # fmt: skip
|
||||
expected_slice = torch.tensor([[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],],device=torch_device) # fmt: skip
|
||||
torch.testing.assert_close(output[0, :7, :1, :1], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
|
||||
|
||||
def test_forecasting_head(self):
|
||||
|
||||
@@ -597,3 +597,7 @@ class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_flex_attention_with_grads():
|
||||
return
|
||||
|
||||
@@ -590,7 +590,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
# change to expected output here
|
||||
expected_slice = torch.tensor(
|
||||
[[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]], device=torch_device
|
||||
[[[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]]], device=torch_device
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
|
||||
@@ -608,7 +608,8 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
# change to expected output here
|
||||
expected_slice = torch.tensor(
|
||||
[[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]], device=torch_device
|
||||
[[[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]]],
|
||||
device=torch_device,
|
||||
)
|
||||
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
|
||||
|
||||
@@ -635,8 +636,7 @@ class PegasusXModelIntegrationTests(unittest.TestCase):
|
||||
batch_input,
|
||||
max_length=512,
|
||||
padding="max_length",
|
||||
truncation_strategy="only_first",
|
||||
truncation=True,
|
||||
truncation="only_first",
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
@@ -872,3 +872,7 @@ class PegasusXStandaloneDecoderModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_flex_attention_with_grads():
|
||||
return
|
||||
|
||||
@@ -677,3 +677,7 @@ class PLBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_retain_grad_hidden_states_attentions(self):
|
||||
return
|
||||
|
||||
@unittest.skip(reason="Decoder cannot keep gradients")
|
||||
def test_flex_attention_with_grads():
|
||||
return
|
||||
|
||||
@@ -342,6 +342,9 @@ class SEWModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
|
||||
@@ -300,6 +300,10 @@ class EncoderDecoderMixin:
|
||||
input_features=None,
|
||||
**kwargs,
|
||||
):
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
decoder_config._attn_implementation = "eager"
|
||||
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
|
||||
@@ -383,6 +383,9 @@ class UniSpeechRobustModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
|
||||
@@ -423,6 +423,9 @@ class UniSpeechSatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
@@ -632,6 +635,9 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
|
||||
@@ -246,6 +246,10 @@ class EncoderDecoderMixin:
|
||||
pixel_values=None,
|
||||
**kwargs,
|
||||
):
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
decoder_config._attn_implementation = "eager"
|
||||
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
@@ -480,6 +484,10 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
pixel_values=None,
|
||||
**kwargs,
|
||||
):
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
decoder_config._attn_implementation = "eager"
|
||||
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
@@ -670,6 +678,10 @@ class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
pixel_values=None,
|
||||
**kwargs,
|
||||
):
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
decoder_config._attn_implementation = "eager"
|
||||
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
@@ -807,6 +819,10 @@ class LayoutLMv32TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
decoder_config._attn_implementation = "eager"
|
||||
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
@@ -929,6 +945,10 @@ class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
decoder_config._attn_implementation = "eager"
|
||||
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
@@ -1047,6 +1067,10 @@ class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase):
|
||||
labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
decoder_config._attn_implementation = "eager"
|
||||
|
||||
# make the decoder inputs a different shape from the encoder inputs to harden the test
|
||||
decoder_input_ids = decoder_input_ids[:, :-1]
|
||||
decoder_attention_mask = decoder_attention_mask[:, :-1]
|
||||
|
||||
@@ -570,6 +570,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
@@ -917,6 +920,9 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
|
||||
Reference in New Issue
Block a user