Test: generate with torch.compile(model.forward) as a fast test (#34544)

This commit is contained in:
Joao Gante
2025-01-28 14:10:38 +00:00
committed by GitHub
parent f48ecd7608
commit ece8c42488
25 changed files with 105 additions and 53 deletions

View File

@@ -364,7 +364,7 @@ class CacheIntegrationTest(unittest.TestCase):
input_ids = gen_out
# We went well beyond the cache length
self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5)
# And it still produces a coherent english
decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)