Fix responses add tests (#39848)

* Quick responses fix

* [serve] Fix responses API and add tests

* Remove typo

* Remove typo

* Tests
This commit is contained in:
Lysandre Debut
2025-08-01 18:06:08 +02:00
committed by GitHub
parent 6ea646a03a
commit 88ead3f518
2 changed files with 131 additions and 4 deletions

View File

@@ -30,8 +30,20 @@ from transformers.utils.import_utils import is_openai_available
if is_openai_available():
from openai import APIConnectionError, OpenAI
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
from openai.types.responses import Response, ResponseCreatedEvent
from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
)
@require_openai
@@ -156,7 +168,7 @@ def async_retry(fn, max_attempts=5, delay=2):
for _ in range(max_attempts):
try:
return await fn(*args, **kwargs)
except aiohttp.client_exceptions.ClientConnectorError:
except (aiohttp.client_exceptions.ClientConnectorError, APIConnectionError):
time.sleep(delay)
return wrapper
@@ -465,4 +477,94 @@ class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, u
thread.start()
# TODO: Response integration tests
@require_openai
class ServeResponsesMixin:
"""
Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API
(`generate` and `continuous_batching`).
"""
@async_retry
async def run_server(self, request):
client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="<KEY>")
stream = client.responses.create(**request)
all_payloads = []
for payload in stream:
all_payloads.append(payload)
return all_payloads
def test_request(self):
"""Tests that an inference using the Responses API works"""
request = {
"model": "Qwen/Qwen2.5-0.5B-Instruct",
"instructions": "You are a helpful assistant.",
"input": "Hello!",
"stream": True,
"max_output_tokens": 1,
}
all_payloads = asyncio.run(self.run_server(request))
print("ok")
order_of_payloads = [
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseContentPartAddedEvent,
ResponseTextDeltaEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseContentPartDoneEvent,
ResponseOutputItemDoneEvent,
ResponseCompletedEvent,
]
self.assertEqual(len(all_payloads), 10)
for payload, payload_type in zip(all_payloads, order_of_payloads):
self.assertIsInstance(payload, payload_type)
# TODO: one test for each request flag, to confirm it is working as expected
# TODO: speed-based test to confirm that KV cache is working across requests
@slow # server startup time is slow on our push CI
@require_openai
class ServeResponsesIntegrationTest(ServeResponsesMixin, unittest.TestCase):
"""Tests the Responses API."""
@classmethod
def setUpClass(cls):
"""Starts a server for tests to connect to."""
cls.port = 8003
args = ServeArguments(port=cls.port, default_seed=42)
serve_command = ServeCommand(args)
thread = Thread(target=serve_command.run)
thread.daemon = True
thread.start()
@slow
def test_full_request(self):
"""Tests that an inference using the Responses API works"""
request = {
"model": "Qwen/Qwen2.5-0.5B-Instruct",
"instructions": "You are a sports assistant designed to craft sports programs.",
"input": "Tell me what you can do.",
"stream": True,
"max_output_tokens": 30,
}
all_payloads = asyncio.run(self.run_server(request))
full_text = ""
for token in all_payloads:
if isinstance(token, ResponseTextDeltaEvent):
full_text += token.delta
# Verify that the system prompt went through.
self.assertTrue(
full_text.startswith(
"As an AI language model, I am designed to assist with various tasks and provide information on different topics related to sports."
)
)