diff --git a/examples/test_examples.py b/examples/test_examples.py index 8ea51b5726..989ec367ee 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -98,12 +98,12 @@ class ExamplesTests(unittest.TestCase): testargs = ["run_generation.py", "--prompt=Hello", + "--length=10", "--seed=42"] model_name = "--model_name=openai-gpt" with patch.object(sys, 'argv', testargs + [model_name]): result = run_generation.main() - self.assertGreaterEqual(result['f1'], 30) - self.assertGreaterEqual(result['exact'], 30) + self.assertGreaterEqual(len(result), 10) if __name__ == "__main__": unittest.main()