@@ -85,6 +85,12 @@ class Lfm2ModelTest(CausalLMModelTest, unittest.TestCase):
|
|||||||
def test_contrastive_generate_low_memory(self):
|
def test_contrastive_generate_low_memory(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
"Lfm2 has a special cache format which is not compatible with compile as it has static address for conv cache"
|
||||||
|
)
|
||||||
|
def test_sdpa_can_compile_dynamic(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
@require_read_token
|
@require_read_token
|
||||||
|
|||||||
@@ -4484,6 +4484,7 @@ class ModelTesterMixin:
|
|||||||
),
|
),
|
||||||
"position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0),
|
"position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0),
|
||||||
"labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
"labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
|
||||||
|
"use_cache": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
# eager backward
|
# eager backward
|
||||||
|
|||||||
Reference in New Issue
Block a user