Whisper: fix static cache CI (#35852)

* fix

* remove overriden method

* small change
This commit is contained in:
Raushan Turganbay
2025-01-30 12:43:00 +01:00
committed by GitHub
parent 9725e5be2f
commit 365fecb4d0
4 changed files with 15 additions and 93 deletions

View File

@@ -3323,8 +3323,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
input_features = input_features.to(torch_device)
eager_generated_ids = model.generate(input_features, max_new_tokens=64)
# Using statiic cache compiles forward for each decoding step, so we don't have to manually compile
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
# compile the forward pass and assert equivalence
static_generated_ids = model.generate(input_features, max_new_tokens=64)
@@ -3379,9 +3379,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
set_seed(42)
eager_generated_ids = model.generate(**inputs, **gen_kwargs)
# compile the forward pass and assert equivalence
# Using statiic cache compiles forward for each decoding step, so we don't have to manually compile
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
set_seed(42)
static_generated_ids = model.generate(**inputs, **gen_kwargs)