Fix responses add tests (#39848)
* Quick responses fix * [serve] Fix responses API and add tests * Remove typo * Remove typo * Tests
This commit is contained in:
@@ -30,8 +30,20 @@ from transformers.utils.import_utils import is_openai_available
|
||||
|
||||
|
||||
if is_openai_available():
|
||||
from openai import APIConnectionError, OpenAI
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
|
||||
from openai.types.responses import Response, ResponseCreatedEvent
|
||||
from openai.types.responses import (
|
||||
Response,
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
)
|
||||
|
||||
|
||||
@require_openai
|
||||
@@ -156,7 +168,7 @@ def async_retry(fn, max_attempts=5, delay=2):
|
||||
for _ in range(max_attempts):
|
||||
try:
|
||||
return await fn(*args, **kwargs)
|
||||
except aiohttp.client_exceptions.ClientConnectorError:
|
||||
except (aiohttp.client_exceptions.ClientConnectorError, APIConnectionError):
|
||||
time.sleep(delay)
|
||||
|
||||
return wrapper
|
||||
@@ -465,4 +477,94 @@ class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, u
|
||||
thread.start()
|
||||
|
||||
|
||||
# TODO: Response integration tests
|
||||
@require_openai
|
||||
class ServeResponsesMixin:
|
||||
"""
|
||||
Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API
|
||||
(`generate` and `continuous_batching`).
|
||||
"""
|
||||
|
||||
@async_retry
|
||||
async def run_server(self, request):
|
||||
client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="<KEY>")
|
||||
stream = client.responses.create(**request)
|
||||
|
||||
all_payloads = []
|
||||
for payload in stream:
|
||||
all_payloads.append(payload)
|
||||
|
||||
return all_payloads
|
||||
|
||||
def test_request(self):
|
||||
"""Tests that an inference using the Responses API works"""
|
||||
|
||||
request = {
|
||||
"model": "Qwen/Qwen2.5-0.5B-Instruct",
|
||||
"instructions": "You are a helpful assistant.",
|
||||
"input": "Hello!",
|
||||
"stream": True,
|
||||
"max_output_tokens": 1,
|
||||
}
|
||||
all_payloads = asyncio.run(self.run_server(request))
|
||||
print("ok")
|
||||
|
||||
order_of_payloads = [
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseCompletedEvent,
|
||||
]
|
||||
|
||||
self.assertEqual(len(all_payloads), 10)
|
||||
for payload, payload_type in zip(all_payloads, order_of_payloads):
|
||||
self.assertIsInstance(payload, payload_type)
|
||||
|
||||
# TODO: one test for each request flag, to confirm it is working as expected
|
||||
# TODO: speed-based test to confirm that KV cache is working across requests
|
||||
|
||||
|
||||
@slow # server startup time is slow on our push CI
|
||||
@require_openai
|
||||
class ServeResponsesIntegrationTest(ServeResponsesMixin, unittest.TestCase):
|
||||
"""Tests the Responses API."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Starts a server for tests to connect to."""
|
||||
cls.port = 8003
|
||||
args = ServeArguments(port=cls.port, default_seed=42)
|
||||
serve_command = ServeCommand(args)
|
||||
thread = Thread(target=serve_command.run)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
@slow
|
||||
def test_full_request(self):
|
||||
"""Tests that an inference using the Responses API works"""
|
||||
|
||||
request = {
|
||||
"model": "Qwen/Qwen2.5-0.5B-Instruct",
|
||||
"instructions": "You are a sports assistant designed to craft sports programs.",
|
||||
"input": "Tell me what you can do.",
|
||||
"stream": True,
|
||||
"max_output_tokens": 30,
|
||||
}
|
||||
all_payloads = asyncio.run(self.run_server(request))
|
||||
|
||||
full_text = ""
|
||||
for token in all_payloads:
|
||||
if isinstance(token, ResponseTextDeltaEvent):
|
||||
full_text += token.delta
|
||||
|
||||
# Verify that the system prompt went through.
|
||||
self.assertTrue(
|
||||
full_text.startswith(
|
||||
"As an AI language model, I am designed to assist with various tasks and provide information on different topics related to sports."
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user