Fix responses add tests (#39848)
* Quick responses fix * [serve] Fix responses API and add tests * Remove typo * Remove typo * Tests
This commit is contained in:
@@ -396,6 +396,9 @@ class ServeArguments:
|
|||||||
log_level: str = field(
|
log_level: str = field(
|
||||||
default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."}
|
default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."}
|
||||||
)
|
)
|
||||||
|
default_seed: Optional[int] = field(
|
||||||
|
default=None, metadata={"help": "The default seed for torch, should be an integer."}
|
||||||
|
)
|
||||||
enable_cors: bool = field(
|
enable_cors: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -451,6 +454,9 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
|
self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
|
||||||
self.enable_cors = self.args.enable_cors
|
self.enable_cors = self.args.enable_cors
|
||||||
|
|
||||||
|
if self.args.default_seed is not None:
|
||||||
|
torch.manual_seed(self.args.default_seed)
|
||||||
|
|
||||||
# Set up logging
|
# Set up logging
|
||||||
transformers_logger = logging.get_logger("transformers")
|
transformers_logger = logging.get_logger("transformers")
|
||||||
transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
|
transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
|
||||||
@@ -1032,7 +1038,26 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
self.last_model = model_id_and_revision
|
self.last_model = model_id_and_revision
|
||||||
model, processor = self.load_model_and_processor(model_id_and_revision)
|
model, processor = self.load_model_and_processor(model_id_and_revision)
|
||||||
|
|
||||||
inputs = processor.apply_chat_template(req["input"], add_generation_prompt=True).to(model.device)
|
if isinstance(req["input"], str):
|
||||||
|
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
|
||||||
|
inputs.append({"role": "user", "content": req["input"]})
|
||||||
|
elif isinstance(req["input"], list):
|
||||||
|
if "instructions" in req:
|
||||||
|
if req["input"][0]["role"] != "system":
|
||||||
|
inputs = [{"role": "system", "content": req["instructions"]}, *req["input"]]
|
||||||
|
else:
|
||||||
|
inputs = req["input"]
|
||||||
|
inputs[0]["content"] = req["instructions"]
|
||||||
|
else:
|
||||||
|
inputs = req["input"]
|
||||||
|
elif isinstance(req["input"], dict):
|
||||||
|
inputs = [{"role": "system", "content": req["instructions"]}] if "instructions" in req else []
|
||||||
|
inputs.append(req["input"])
|
||||||
|
else:
|
||||||
|
raise ValueError("inputs should be a list, dict, or str")
|
||||||
|
|
||||||
|
inputs = processor.apply_chat_template(inputs, add_generation_prompt=True, return_tensors="pt")
|
||||||
|
inputs = inputs.to(model.device)
|
||||||
request_id = req.get("previous_response_id", "req_0")
|
request_id = req.get("previous_response_id", "req_0")
|
||||||
|
|
||||||
generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
|
generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
|
||||||
|
|||||||
@@ -30,8 +30,20 @@ from transformers.utils.import_utils import is_openai_available
|
|||||||
|
|
||||||
|
|
||||||
if is_openai_available():
|
if is_openai_available():
|
||||||
|
from openai import APIConnectionError, OpenAI
|
||||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
|
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
|
||||||
from openai.types.responses import Response, ResponseCreatedEvent
|
from openai.types.responses import (
|
||||||
|
Response,
|
||||||
|
ResponseCompletedEvent,
|
||||||
|
ResponseContentPartAddedEvent,
|
||||||
|
ResponseContentPartDoneEvent,
|
||||||
|
ResponseCreatedEvent,
|
||||||
|
ResponseInProgressEvent,
|
||||||
|
ResponseOutputItemAddedEvent,
|
||||||
|
ResponseOutputItemDoneEvent,
|
||||||
|
ResponseTextDeltaEvent,
|
||||||
|
ResponseTextDoneEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@require_openai
|
@require_openai
|
||||||
@@ -156,7 +168,7 @@ def async_retry(fn, max_attempts=5, delay=2):
|
|||||||
for _ in range(max_attempts):
|
for _ in range(max_attempts):
|
||||||
try:
|
try:
|
||||||
return await fn(*args, **kwargs)
|
return await fn(*args, **kwargs)
|
||||||
except aiohttp.client_exceptions.ClientConnectorError:
|
except (aiohttp.client_exceptions.ClientConnectorError, APIConnectionError):
|
||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
@@ -465,4 +477,94 @@ class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, u
|
|||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
|
|
||||||
# TODO: Response integration tests
|
@require_openai
class ServeResponsesMixin:
    """
    Mixin class for the Responses API tests, so the same tests can be reused by any test class that starts a server.

    The host class must define `port`, the port of the local server the requests are sent to.
    """

    @async_retry
    async def run_server(self, request):
        """
        Sends `request` to the Responses API of the local server and collects everything it returns.

        Args:
            request (`dict`): Keyword arguments forwarded as-is to `client.responses.create`.

        Returns:
            `list`: The payloads yielded by the response stream, in the order they were received.
        """
        client = OpenAI(base_url=f"http://localhost:{self.port}/v1", api_key="<KEY>")
        stream = client.responses.create(**request)

        # Exhaust the stream so that callers can make assertions on the full sequence of payloads
        return list(stream)

    def test_request(self):
        """Tests that an inference using the Responses API works"""

        request = {
            "model": "Qwen/Qwen2.5-0.5B-Instruct",
            "instructions": "You are a helpful assistant.",
            "input": "Hello!",
            "stream": True,
            "max_output_tokens": 1,
        }
        all_payloads = asyncio.run(self.run_server(request))

        # Expected event types, in the exact order the server is expected to emit them
        order_of_payloads = [
            ResponseCreatedEvent,
            ResponseInProgressEvent,
            ResponseOutputItemAddedEvent,
            ResponseContentPartAddedEvent,
            ResponseTextDeltaEvent,
            ResponseTextDeltaEvent,
            ResponseTextDoneEvent,
            ResponseContentPartDoneEvent,
            ResponseOutputItemDoneEvent,
            ResponseCompletedEvent,
        ]

        # Check the length first: `zip` would otherwise silently ignore missing or extra payloads
        self.assertEqual(len(all_payloads), len(order_of_payloads))
        for payload, payload_type in zip(all_payloads, order_of_payloads):
            self.assertIsInstance(payload, payload_type)

    # TODO: one test for each request flag, to confirm it is working as expected
    # TODO: speed-based test to confirm that KV cache is working across requests
|
||||||
|
|
||||||
|
|
||||||
|
@slow  # server startup time is slow on our push CI
@require_openai
class ServeResponsesIntegrationTest(ServeResponsesMixin, unittest.TestCase):
    """Integration tests for the Responses API, run against a server started by the test class itself."""

    @classmethod
    def setUpClass(cls):
        """Launches a server on a background daemon thread, for the tests to send their requests to."""
        cls.port = 8003
        serve_command = ServeCommand(ServeArguments(port=cls.port, default_seed=42))
        # Daemon thread: the server must not keep the test process alive once the tests are done
        Thread(target=serve_command.run, daemon=True).start()

    @slow
    def test_full_request(self):
        """Checks the text generated for a full Responses API request that includes system instructions."""

        request = {
            "model": "Qwen/Qwen2.5-0.5B-Instruct",
            "instructions": "You are a sports assistant designed to craft sports programs.",
            "input": "Tell me what you can do.",
            "stream": True,
            "max_output_tokens": 30,
        }
        events = asyncio.run(self.run_server(request))

        # Only the text delta events carry generated text; stitch them back together in order
        generated_text = "".join(event.delta for event in events if isinstance(event, ResponseTextDeltaEvent))

        # Verify that the system prompt went through.
        expected_prefix = (
            "As an AI language model, I am designed to assist with various tasks and provide information on different topics related to sports."
        )
        self.assertTrue(generated_text.startswith(expected_prefix))
|
||||||
|
|||||||
Reference in New Issue
Block a user