Responses API in transformers serve (#39155)

* Scaffolding

* Explicit content

* Naïve Responses API streaming implementation

* Cleanup

* Responses API (to be merged into #39155) (#39338)

* Scaffolding

* Explicit content

* Naïve Responses API streaming implementation

* Cleanup

* use openai

* validate request, including detecting unused fields

* dict indexing

* dict var access

* tmp commit (tests failing)

* add slow

* use oai output type in completions

* (little rebase errors)

* working spec?

* guard type hint

* type hints. fix state (CB can now load different models)

* type hints; fn names; error type

* add docstrings

* responses + kv cache

* metadata support; fix kv cache; error event

* add output_index and content_index

* docstrings

* add test_build_response_event

* docs/comments

* gate test requirements; terminate cb manager on model switch

* nasty type hints

* more type hints

* disable validation by default; enable force models

* todo

---------

Co-authored-by: Lysandre <hi@lysand.re>

* Slight bugfixes

* PR comments from #39338

* make fixup

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
Lysandre Debut
2025-07-16 14:16:16 +02:00
committed by GitHub
parent c8524aeb07
commit de5ca373ac
8 changed files with 937 additions and 380 deletions

View File

@@ -24,9 +24,16 @@ from parameterized import parameterized
import transformers.commands.transformers_cli as cli
from transformers import GenerationConfig
from transformers.commands.serving import ServeArguments, ServeCommand
from transformers.testing_utils import CaptureStd, slow
from transformers.testing_utils import CaptureStd, require_openai, slow
from transformers.utils.import_utils import is_openai_available
if is_openai_available():
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
from openai.types.responses import Response, ResponseCreatedEvent
@require_openai
class ServeCLITest(unittest.TestCase):
def test_help(self):
"""Minimal test: we can invoke the help command."""
@@ -49,36 +56,94 @@ class ServeCLITest(unittest.TestCase):
self.assertEqual(parsed_args.host, "0.0.0.0")
self.assertEqual(parsed_args.port, 9000)
def test_completions_build_chunk(self):
"""Tests that the chunks are correctly built for the Completions API."""
def test_build_chat_completion_chunk(self):
"""
Tests that the chunks are correctly built for the Chat Completion API. The `choices` checks implictly
confirm that empty fields are not emitted.
"""
dummy = ServeCommand.__new__(ServeCommand)
dummy.args = type("Args", (), {})()
dummy.loaded_model = "dummy_model@main"
# The keys for these fields must be present in every chunk
MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"]
# Case 1: most fields are provided
chunk = ServeCommand.build_chunk(dummy, request_id="req0", content="hello", finish_reason="stop", role="user")
self.assertIn("chat.completion.chunk", chunk)
self.assertIn("data:", chunk)
chunk = ServeCommand.build_chat_completion_chunk(
dummy, request_id="req0", content="hello", finish_reason="stop", role="user"
)
for field in MANDATORY_FIELDS:
self.assertIn(field, chunk)
self.assertIn(
'"choices": [{"delta": {"content": "hello", "role": "user"}, "index": 0, "finish_reason": "stop"}]', chunk
'"choices":[{"delta":{"content":"hello","role":"user"},"finish_reason":"stop","index":0}]', chunk
)
# Case 2: only the role is provided -- other fields in 'choices' are omitted
chunk = ServeCommand.build_chunk(dummy, request_id="req0", role="user")
self.assertIn("chat.completion.chunk", chunk)
self.assertIn("data:", chunk)
self.assertIn('"choices": [{"delta": {"role": "user"}, "index": 0}]', chunk)
chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user")
for field in MANDATORY_FIELDS:
self.assertIn(field, chunk)
self.assertIn('"choices":[{"delta":{"role":"user"},"index":0}]', chunk)
# Case 3: only the content is provided -- other fields in 'choices' are omitted
chunk = ServeCommand.build_chunk(dummy, request_id="req0", content="hello")
self.assertIn("chat.completion.chunk", chunk)
self.assertIn("data:", chunk)
self.assertIn('"choices": [{"delta": {"content": "hello"}, "index": 0}]', chunk)
chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello")
for field in MANDATORY_FIELDS:
self.assertIn(field, chunk)
self.assertIn('"choices":[{"delta":{"content":"hello"},"index":0}]', chunk)
# Case 4: tool calls support a list of nested dictionaries
chunk = ServeCommand.build_chunk(dummy, request_id="req0", tool_calls=[{"foo1": "bar1", "foo2": "bar2"}])
self.assertIn("chat.completion.chunk", chunk)
self.assertIn("data:", chunk)
self.assertIn('"choices": [{"delta": {"tool_calls": [{"foo1": "bar1", "foo2": "bar2"}]}, "index": 0}]', chunk)
# Case 4: tool calls support a list of ChoiceDeltaToolCall objects
tool_call = ChoiceDeltaToolCall(
index=0,
function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'),
type="function",
)
chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call])
for field in MANDATORY_FIELDS:
self.assertIn(field, chunk)
expected_choices_content = (
'choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"foo1\\": \\"bar1\\", '
'\\"foo2\\": \\"bar2\\"}","name":"foo_bar"},"type":"function"}]},"index":0}]'
)
self.assertIn(expected_choices_content, chunk)
def test_build_response_event(self):
"""
Tests that the events are correctly built for the Response API.
Contrarily to the Chat Completion API, the Response API has a wide set of possible output objects. This test
only checks a few basic assumptions -- we rely on OpenAI's pydantic models to enforce the correct schema.
"""
dummy = ServeCommand.__new__(ServeCommand)
dummy.args = type("Args", (), {})()
response_created = ResponseCreatedEvent(
type="response.created",
sequence_number=0,
response=Response(
id="resp_0",
created_at=time.time(),
status="queued",
model="dummy_model@main",
instructions=None, # <--- is set to None = should NOT be in the output.
text={"format": {"type": "text"}},
object="response",
tools=[], # <--- empty lists should be in the output (they are often mandatory fields)
output=[],
parallel_tool_calls=False,
tool_choice="auto",
metadata=None,
),
)
event = dummy.build_response_event(response_created)
self.assertTrue(event.startswith("data: ")) # Sanity check: event formatting
self.assertIn('"model":"dummy_model@main"', event) # Sanity check: set field
self.assertIn('"status":"queued"', event)
self.assertIn("tools", event) # empty lists should be in the output
self.assertIn("output", event)
self.assertNotIn("instructions", event) # None fields should NOT be in the output
self.assertNotIn("metadata", event)
self.assertNotIn("error", event) # Unset optional fields should NOT be in the output
self.assertNotIn("top_p", event)
def async_retry(fn, max_attempts=5, delay=2):
@@ -105,7 +170,7 @@ class ServeCompletionsMixin:
@async_retry
async def run_server(self, request):
client = AsyncInferenceClient("http://localhost:8000")
client = AsyncInferenceClient(f"http://localhost:{self.port}")
stream = client.chat_completion(**request)
all_payloads = []
@@ -119,8 +184,7 @@ class ServeCompletionsMixin:
[
("default_request", {}),
("one_token", {"max_tokens": 1}),
# TODO: CB fails next case, seems like it is unable to switch models. fix me
# ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}),
("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}),
(
"tool_call",
{
@@ -191,20 +255,20 @@ class ServeCompletionsMixin:
# sets `do_sample=True`
self.assertEqual(output_text, '<think>\nOkay, the user just asked, "')
# TODO: implement API-compliant error handling, and then test it
# See https://platform.openai.com/docs/guides/error-codes,
# TODO: one test for each request flag, to confirm it is working as expected
# TODO: speed-based test to confirm that KV cache is working across requests
@slow # TODO (joao): this shouldn't be needed
class ServeCompletionsGenerateTest(ServeCompletionsMixin, unittest.TestCase):
@slow # server startup time is slow on our push CI
@require_openai
class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
"""Tests the `generate` version of the Completions API."""
@classmethod
def setUpClass(cls):
"""Starts a server for tests to connect to."""
args = ServeArguments()
cls.port = 8001
args = ServeArguments(port=cls.port)
serve_command = ServeCommand(args)
thread = Thread(target=serve_command.run)
thread.daemon = True
@@ -287,15 +351,20 @@ class ServeCompletionsGenerateTest(ServeCompletionsMixin, unittest.TestCase):
self.assertTrue(all(reason is None for reason in finish_reasons[:-1]))
@slow # TODO (joao): this shouldn't be needed
class ServeCompletionsContinuousBatchingTest(ServeCompletionsMixin, unittest.TestCase):
@slow # server startup time is slow on our push CI
@require_openai
class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
"""Tests the `continuous_batching` version of the Completions API."""
@classmethod
def setUpClass(cls):
"""Starts a server for tests to connect to."""
args = ServeArguments(attn_implementation="sdpa_paged") # important: toggle continuous batching
cls.port = 8002
args = ServeArguments(port=cls.port, attn_implementation="sdpa_paged") # important: toggle continuous batching
serve_command = ServeCommand(args)
thread = Thread(target=serve_command.run)
thread.daemon = True
thread.start()
# TODO: Response integration tests