Fix missing test in torch_job (#33593)
fix missing tests Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -89,7 +89,6 @@ if is_torch_available():
|
|||||||
from transformers.generation.utils import _speculative_sampling
|
from transformers.generation.utils import _speculative_sampling
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.generate
|
|
||||||
class GenerationTesterMixin:
|
class GenerationTesterMixin:
|
||||||
model_tester = None
|
model_tester = None
|
||||||
all_generative_model_classes = ()
|
all_generative_model_classes = ()
|
||||||
@@ -2035,6 +2034,7 @@ class GenerationTesterMixin:
|
|||||||
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
|
output_compiled = compiled_generate(model_inputs, generation_config=generation_config)
|
||||||
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
|
self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist())
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
def test_generate_methods_with_num_logits_to_keep(self):
|
def test_generate_methods_with_num_logits_to_keep(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
|
||||||
@@ -2063,6 +2063,7 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
@is_flaky() # assisted generation tests are flaky (minor fp ops differences)
|
@is_flaky() # assisted generation tests are flaky (minor fp ops differences)
|
||||||
def test_assisted_decoding_with_num_logits_to_keep(self):
|
def test_assisted_decoding_with_num_logits_to_keep(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
|
|||||||
Reference in New Issue
Block a user