diff --git a/examples/test_examples.py b/examples/test_examples.py index 2f88d129f8..688401ebc9 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -103,7 +103,7 @@ class ExamplesTests(unittest.TestCase): "--seed=42"] model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt") - with patch.object(sys, 'argv', testargs + [model_name]): + with patch.object(sys, 'argv', testargs + [model_type, model_name]): result = run_generation.main() self.assertGreaterEqual(len(result), 10)