diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 5cda62c397..c738ca4645 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1947,7 +1947,7 @@ class GenerationTesterMixin: ) @parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5) - @require_torch_gpu + @require_torch_accelerator @pytest.mark.generate def test_offloaded_cache_implementation(self, cache_implementation): """Tests we can generate by indicating `cache_implementation` for each possible cache class"""