diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index 06a3bc6b92..d8d7d1b22b 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -396,6 +396,9 @@ class ServeArguments: log_level: str = field( default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."} ) + default_seed: Optional[int] = field( + default=None, metadata={"help": "The default seed for torch, should be an integer."} + ) enable_cors: bool = field( default=False, metadata={ @@ -451,6 +454,9 @@ class ServeCommand(BaseTransformersCLICommand): self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged" self.enable_cors = self.args.enable_cors + if self.args.default_seed is not None: + torch.manual_seed(self.args.default_seed) + # Set up logging transformers_logger = logging.get_logger("transformers") transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()]) @@ -1032,7 +1038,26 @@ class ServeCommand(BaseTransformersCLICommand): self.last_model = model_id_and_revision model, processor = self.load_model_and_processor(model_id_and_revision) - inputs = processor.apply_chat_template(req["input"], add_generation_prompt=True).to(model.device) + if isinstance(req["input"], str): + inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] + inputs.append({"role": "user", "content": req["input"]}) + elif isinstance(req["input"], list): + if "instructions" in req: + if req["input"][0]["role"] != "system": + inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]] + else: + inputs = req["input"] + inputs[0]["content"] = req["instructions"] + else: + inputs = req["input"] + elif isinstance(req["input"], dict): + inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else [] + inputs.append(req["input"]) + else: + raise ValueError("inputs should be a list, dict, or str") + + inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt") + inputs = inputs.to(model.device) request_id = req.get("previous_response_id", "req_0") generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True) diff --git a/tests/commands/test_serving.py b/tests/commands/test_serving.py index f0592686e8..87b5f61e2c 100644 --- a/tests/commands/test_serving.py +++ b/tests/commands/test_serving.py @@ -30,8 +30,20 @@ from transformers.utils.import_utils import is_openai_available if is_openai_available(): + from openai import APIConnectionError, OpenAI from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction - from openai.types.responses import Response, ResponseCreatedEvent + from openai.types.responses import ( + Response, + ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, + ) @require_openai @@ -156,7 +168,7 @@ def async_retry(fn, max_attempts=5, delay=2): for _ in range(max_attempts): try: return await fn(*args, **kwargs) - except aiohttp.client_exceptions.ClientConnectorError: + except (aiohttp.client_exceptions.ClientConnectorError, APIConnectionError): time.sleep(delay) return wrapper @@ -465,4 +477,94 @@ class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, u thread.start() -# TODO: Response integration tests +@require_openai +class ServeResponsesMixin: + """ + Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API + (`generate` and `continuous_batching`). + """ + + @async_retry + async def run_server(self, request): + client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="") + stream = client.responses.create(**request) + + all_payloads = [] + for payload in stream: + all_payloads.append(payload) + + return all_payloads + + def test_request(self): + """Tests that an inference using the Responses API works""" + + request = { + "model": "Qwen/Qwen2.5-0.5B-Instruct", + "instructions": "You are a helpful assistant.", + "input": "Hello!", + "stream": True, + "max_output_tokens": 1, + } + all_payloads = asyncio.run(self.run_server(request)) + print("ok") + + order_of_payloads = [ + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseOutputItemAddedEvent, + ResponseContentPartAddedEvent, + ResponseTextDeltaEvent, + ResponseTextDeltaEvent, + ResponseTextDoneEvent, + ResponseContentPartDoneEvent, + ResponseOutputItemDoneEvent, + ResponseCompletedEvent, + ] + + self.assertEqual(len(all_payloads), 10) + for payload, payload_type in zip(all_payloads, order_of_payloads): + self.assertIsInstance(payload, payload_type) + + # TODO: one test for each request flag, to confirm it is working as expected + # TODO: speed-based test to confirm that KV cache is working across requests + + +@slow # server startup time is slow on our push CI +@require_openai +class ServeResponsesIntegrationTest(ServeResponsesMixin, unittest.TestCase): + """Tests the Responses API.""" + + @classmethod + def setUpClass(cls): + """Starts a server for tests to connect to.""" + cls.port = 8003 + args = ServeArguments(port=cls.port, default_seed=42) + serve_command = ServeCommand(args) + thread = Thread(target=serve_command.run) + thread.daemon = True + thread.start() + + @slow + def test_full_request(self): + """Tests that an inference using the Responses API works""" + + request = { + "model": "Qwen/Qwen2.5-0.5B-Instruct", + "instructions": "You are a sports assistant designed to craft sports programs.", + "input": "Tell me what you can do.", + "stream": True, + "max_output_tokens": 30, + } + all_payloads = asyncio.run(self.run_server(request)) + + full_text = "" + for token in all_payloads: + if isinstance(token, ResponseTextDeltaEvent): + full_text += token.delta + + # Verify that the system prompt went through. + self.assertTrue( + full_text.startswith( + "As an AI language model, I am designed to assist with various tasks and provide information on different topics related to sports." + ) + )