Fix responses add tests (#39848)

* Quick responses fix

* [serve] Fix responses API and add tests

* Remove typo

* Remove typo

* Tests
This commit is contained in:
Lysandre Debut
2025-08-01 18:06:08 +02:00
committed by GitHub
parent 6ea646a03a
commit 88ead3f518
2 changed files with 131 additions and 4 deletions

View File

@@ -396,6 +396,9 @@ class ServeArguments:
log_level: str = field( log_level: str = field(
default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."} default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."}
) )
default_seed: Optional[int] = field(
default=None, metadata={"help": "The default seed for torch, should be an integer."}
)
enable_cors: bool = field( enable_cors: bool = field(
default=False, default=False,
metadata={ metadata={
@@ -451,6 +454,9 @@ class ServeCommand(BaseTransformersCLICommand):
self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged" self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
self.enable_cors = self.args.enable_cors self.enable_cors = self.args.enable_cors
if self.args.default_seed is not None:
torch.manual_seed(self.args.default_seed)
# Set up logging # Set up logging
transformers_logger = logging.get_logger("transformers") transformers_logger = logging.get_logger("transformers")
transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()]) transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
@@ -1032,7 +1038,26 @@ class ServeCommand(BaseTransformersCLICommand):
self.last_model = model_id_and_revision self.last_model = model_id_and_revision
model, processor = self.load_model_and_processor(model_id_and_revision) model, processor = self.load_model_and_processor(model_id_and_revision)
inputs = processor.apply_chat_template(req["input"], add_generation_prompt=True).to(model.device) if isinstance(req["input"], str):
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
inputs.append({"role": "user", "content": req["input"]})
elif isinstance(req["input"], list):
if "instructions" in req:
if req["input"][0]["role"] != "system":
inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]]
else:
inputs = req["input"]
inputs[0]["content"] = req["instructions"]
else:
inputs = req["input"]
elif isinstance(req["input"], dict):
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
inputs.append(req["input"])
else:
raise ValueError("inputs should be a list, dict, or str")
inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt")
inputs = inputs.to(model.device)
request_id = req.get("previous_response_id", "req_0") request_id = req.get("previous_response_id", "req_0")
generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True) generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)

View File

@@ -30,8 +30,20 @@ from transformers.utils.import_utils import is_openai_available
if is_openai_available(): if is_openai_available():
from openai import APIConnectionError, OpenAI
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
from openai.types.responses import Response, ResponseCreatedEvent from openai.types.responses import (
Response,
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
)
@require_openai @require_openai
@@ -156,7 +168,7 @@ def async_retry(fn, max_attempts=5, delay=2):
for _ in range(max_attempts): for _ in range(max_attempts):
try: try:
return await fn(*args, **kwargs) return await fn(*args, **kwargs)
except aiohttp.client_exceptions.ClientConnectorError: except (aiohttp.client_exceptions.ClientConnectorError, APIConnectionError):
time.sleep(delay) time.sleep(delay)
return wrapper return wrapper
@@ -465,4 +477,94 @@ class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, u
thread.start() thread.start()
# TODO: Response integration tests @require_openai
class ServeResponsesMixin:
"""
Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API
(`generate` and `continuous_batching`).
"""
@async_retry
async def run_server(self, request):
client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="<KEY>")
stream = client.responses.create(**request)
all_payloads = []
for payload in stream:
all_payloads.append(payload)
return all_payloads
def test_request(self):
"""Tests that an inference using the Responses API works"""
request = {
"model": "Qwen/Qwen2.5-0.5B-Instruct",
"instructions": "You are a helpful assistant.",
"input": "Hello!",
"stream": True,
"max_output_tokens": 1,
}
all_payloads = asyncio.run(self.run_server(request))
print("ok")
order_of_payloads = [
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseOutputItemAddedEvent,
ResponseContentPartAddedEvent,
ResponseTextDeltaEvent,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseContentPartDoneEvent,
ResponseOutputItemDoneEvent,
ResponseCompletedEvent,
]
self.assertEqual(len(all_payloads), 10)
for payload, payload_type in zip(all_payloads, order_of_payloads):
self.assertIsInstance(payload, payload_type)
# TODO: one test for each request flag, to confirm it is working as expected
# TODO: speed-based test to confirm that KV cache is working across requests
@slow # server startup time is slow on our push CI
@require_openai
class ServeResponsesIntegrationTest(ServeResponsesMixin, unittest.TestCase):
"""Tests the Responses API."""
@classmethod
def setUpClass(cls):
"""Starts a server for tests to connect to."""
cls.port = 8003
args = ServeArguments(port=cls.port, default_seed=42)
serve_command = ServeCommand(args)
thread = Thread(target=serve_command.run)
thread.daemon = True
thread.start()
@slow
def test_full_request(self):
"""Tests that an inference using the Responses API works"""
request = {
"model": "Qwen/Qwen2.5-0.5B-Instruct",
"instructions": "You are a sports assistant designed to craft sports programs.",
"input": "Tell me what you can do.",
"stream": True,
"max_output_tokens": 30,
}
all_payloads = asyncio.run(self.run_server(request))
full_text = ""
for token in all_payloads:
if isinstance(token, ResponseTextDeltaEvent):
full_text += token.delta
# Verify that the system prompt went through.
self.assertTrue(
full_text.startswith(
"As an AI language model, I am designed to assist with various tasks and provide information on different topics related to sports."
)
)