@@ -85,6 +85,12 @@ class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase):
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Lfm2 has a special cache format which is not compatible with compile as it has static address for conv cache"
|
||||
)
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
@require_read_token
|
||||
|
||||
@@ -4484,6 +4484,7 @@ class ModelTesterMixin:
|
||||
),
|
||||
"position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0),
|
||||
"labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||
"use_cache": False,
|
||||
}
|
||||
|
||||
# eager backward
|
||||
|
||||
Reference in New Issue
Block a user