Whisper: fix static cache CI (#35852)
* fix * remove overriden method * small change
This commit is contained in:
committed by
GitHub
parent
9725e5be2f
commit
365fecb4d0
@@ -3323,8 +3323,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
input_features = input_features.to(torch_device)
|
||||
eager_generated_ids = model.generate(input_features, max_new_tokens=64)
|
||||
|
||||
# Using statiic cache compiles forward for each decoding step, so we don't have to manually compile
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
# compile the forward pass and assert equivalence
|
||||
static_generated_ids = model.generate(input_features, max_new_tokens=64)
|
||||
@@ -3379,9 +3379,8 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
set_seed(42)
|
||||
eager_generated_ids = model.generate(**inputs, **gen_kwargs)
|
||||
|
||||
# compile the forward pass and assert equivalence
|
||||
# Using statiic cache compiles forward for each decoding step, so we don't have to manually compile
|
||||
model.generation_config.cache_implementation = "static"
|
||||
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
set_seed(42)
|
||||
static_generated_ids = model.generate(**inputs, **gen_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user