Fix responses add tests (#39848)
* Quick responses fix * [serve] Fix responses API and add tests * Remove typo * Remove typo * Tests
This commit is contained in:
@@ -396,6 +396,9 @@ class ServeArguments:
|
||||
log_level: str = field(
|
||||
default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."}
|
||||
)
|
||||
default_seed: Optional[int] = field(
|
||||
default=None, metadata={"help": "The default seed for torch, should be an integer."}
|
||||
)
|
||||
enable_cors: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@@ -451,6 +454,9 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
|
||||
self.enable_cors = self.args.enable_cors
|
||||
|
||||
if self.args.default_seed is not None:
|
||||
torch.manual_seed(self.args.default_seed)
|
||||
|
||||
# Set up logging
|
||||
transformers_logger = logging.get_logger("transformers")
|
||||
transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
|
||||
@@ -1032,7 +1038,26 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
self.last_model = model_id_and_revision
|
||||
model, processor = self.load_model_and_processor(model_id_and_revision)
|
||||
|
||||
inputs = processor.apply_chat_template(req["input"], add_generation_prompt=True).to(model.device)
|
||||
if isinstance(req["input"], str):
|
||||
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
|
||||
inputs.append({"role": "user", "content": req["input"]})
|
||||
elif isinstance(req["input"], list):
|
||||
if "instructions" in req:
|
||||
if req["input"][0]["role"] != "system":
|
||||
inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]]
|
||||
else:
|
||||
inputs = req["input"]
|
||||
inputs[0]["content"] = req["instructions"]
|
||||
else:
|
||||
inputs = req["input"]
|
||||
elif isinstance(req["input"], dict):
|
||||
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
|
||||
inputs.append(req["input"])
|
||||
else:
|
||||
raise ValueError("inputs should be a list, dict, or str")
|
||||
|
||||
inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt")
|
||||
inputs = inputs.to(model.device)
|
||||
request_id = req.get("previous_response_id", "req_0")
|
||||
|
||||
generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
|
||||
|
||||
@@ -30,8 +30,20 @@ from transformers.utils.import_utils import is_openai_available
|
||||
|
||||
|
||||
if is_openai_available():
|
||||
from openai import APIConnectionError, OpenAI
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
|
||||
from openai.types.responses import Response, ResponseCreatedEvent
|
||||
from openai.types.responses import (
|
||||
Response,
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
)
|
||||
|
||||
|
||||
@require_openai
|
||||
@@ -156,7 +168,7 @@ def async_retry(fn, max_attempts=5, delay=2):
|
||||
for _ in range(max_attempts):
|
||||
try:
|
||||
return await fn(*args, **kwargs)
|
||||
except aiohttp.client_exceptions.ClientConnectorError:
|
||||
except (aiohttp.client_exceptions.ClientConnectorError, APIConnectionError):
|
||||
time.sleep(delay)
|
||||
|
||||
return wrapper
|
||||
@@ -465,4 +477,94 @@ class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, u
|
||||
thread.start()
|
||||
|
||||
|
||||
# TODO: Response integration tests
|
||||
@require_openai
|
||||
class ServeResponsesMixin:
|
||||
"""
|
||||
Mixin class for the Completions API tests, to seamlessly replicate tests across the two versions of the API
|
||||
(`generate` and `continuous_batching`).
|
||||
"""
|
||||
|
||||
@async_retry
|
||||
async def run_server(self, request):
|
||||
client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="<KEY>")
|
||||
stream = client.responses.create(**request)
|
||||
|
||||
all_payloads = []
|
||||
for payload in stream:
|
||||
all_payloads.append(payload)
|
||||
|
||||
return all_payloads
|
||||
|
||||
def test_request(self):
|
||||
"""Tests that an inference using the Responses API works"""
|
||||
|
||||
request = {
|
||||
"model": "Qwen/Qwen2.5-0.5B-Instruct",
|
||||
"instructions": "You are a helpful assistant.",
|
||||
"input": "Hello!",
|
||||
"stream": True,
|
||||
"max_output_tokens": 1,
|
||||
}
|
||||
all_payloads = asyncio.run(self.run_server(request))
|
||||
print("ok")
|
||||
|
||||
order_of_payloads = [
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseCompletedEvent,
|
||||
]
|
||||
|
||||
self.assertEqual(len(all_payloads), 10)
|
||||
for payload, payload_type in zip(all_payloads, order_of_payloads):
|
||||
self.assertIsInstance(payload, payload_type)
|
||||
|
||||
# TODO: one test for each request flag, to confirm it is working as expected
|
||||
# TODO: speed-based test to confirm that KV cache is working across requests
|
||||
|
||||
|
||||
@slow # server startup time is slow on our push CI
|
||||
@require_openai
|
||||
class ServeResponsesIntegrationTest(ServeResponsesMixin, unittest.TestCase):
|
||||
"""Tests the Responses API."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Starts a server for tests to connect to."""
|
||||
cls.port = 8003
|
||||
args = ServeArguments(port=cls.port, default_seed=42)
|
||||
serve_command = ServeCommand(args)
|
||||
thread = Thread(target=serve_command.run)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
@slow
|
||||
def test_full_request(self):
|
||||
"""Tests that an inference using the Responses API works"""
|
||||
|
||||
request = {
|
||||
"model": "Qwen/Qwen2.5-0.5B-Instruct",
|
||||
"instructions": "You are a sports assistant designed to craft sports programs.",
|
||||
"input": "Tell me what you can do.",
|
||||
"stream": True,
|
||||
"max_output_tokens": 30,
|
||||
}
|
||||
all_payloads = asyncio.run(self.run_server(request))
|
||||
|
||||
full_text = ""
|
||||
for token in all_payloads:
|
||||
if isinstance(token, ResponseTextDeltaEvent):
|
||||
full_text += token.delta
|
||||
|
||||
# Verify that the system prompt went through.
|
||||
self.assertTrue(
|
||||
full_text.startswith(
|
||||
"As an AI language model, I am designed to assist with various tasks and provide information on different topics related to sports."
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user