Responses API in transformers serve (#39155)

* Scaffolding

* Explicit content

* Naïve Responses API streaming implementation

* Cleanup

* Responses API (to be merged into #39155) (#39338)

* Scaffolding

* Explicit content

* Naïve Responses API streaming implementation

* Cleanup

* use openai

* validate request, including detecting unused fields

* dict indexing

* dict var access

* tmp commit (tests failing)

* add slow

* use oai output type in completions

* (little rebase errors)

* working spec?

* guard type hint

* type hints. fix state (CB can now load different models)

* type hints; fn names; error type

* add docstrings

* responses + kv cache

* metadata support; fix kv cache; error event

* add output_index and content_index

* docstrings

* add test_build_response_event

* docs/comments

* gate test requirements; terminate cb manager on model switch

* nasty type hints

* more type hints

* disable validation by default; enable force models

* todo

---------

Co-authored-by: Lysandre <hi@lysand.re>

* Slight bugfixes

* PR comments from #39338

* make fixup

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
Lysandre Debut
2025-07-16 14:16:16 +02:00
committed by GitHub
parent c8524aeb07
commit de5ca373ac
8 changed files with 937 additions and 380 deletions

View File

@@ -71,7 +71,7 @@ vllm serve Qwen/Qwen2.5-1.5B-Instruct \
> This section is experimental and subject to change in future versions > This section is experimental and subject to change in future versions
<!-- TODO: LLMs -> models, after we add audio/image input/output support --> <!-- TODO: LLMs -> models, after we add audio/image input/output support -->
You can serve LLMs supported by `transformers` with the `transformers serve` CLI. It spawns a local server that offers a chat Completions API compatible with the OpenAI SDK, which is the _de facto_ standard for LLM conversations. This way, you can use the server from many third party applications, or test it using the `transformers chat` CLI ([docs](conversations.md#chat-cli)). You can serve LLMs supported by `transformers` with the `transformers serve` CLI. It spawns a local server that offers a Chat Completion API or a Response API compatible with the OpenAI SDK, which are the _de facto_ standard for LLM conversations. This way, you can use the server from many third party applications, or test it using the `transformers chat` CLI ([docs](conversations.md#chat-cli)).
To launch a server, simply use the `transformers serve` CLI command: To launch a server, simply use the `transformers serve` CLI command:

View File

@@ -137,6 +137,7 @@ _deps = [
"onnxconverter-common", "onnxconverter-common",
"onnxruntime-tools>=1.4.2", "onnxruntime-tools>=1.4.2",
"onnxruntime>=1.4.0", "onnxruntime>=1.4.0",
"openai",
"opencv-python", "opencv-python",
"optimum-benchmark>=0.3.0", "optimum-benchmark>=0.3.0",
"optuna", "optuna",
@@ -314,7 +315,7 @@ extras["hub-kernels"] = deps_list("kernels")
extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"] extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette") + extras["torch"] extras["serving"] = deps_list("openai", "pydantic", "uvicorn", "fastapi", "starlette") + extras["torch"]
extras["audio"] = deps_list( extras["audio"] = deps_list(
"librosa", "librosa",
"pyctcdecode", "pyctcdecode",

View File

@@ -471,7 +471,7 @@ class ChatCommand(BaseTransformersCLICommand):
# This is a chat session, so we have a few non-standard defaults # This is a chat session, so we have a few non-standard defaults
# !!!!!!!!! # !!!!!!!!!
generation_config = copy.deepcopy(model_generation_config) generation_config = copy.deepcopy(model_generation_config)
generation_config.update({"do_sample": True, "max_new_tokens": 256}) generation_config.update(**{"do_sample": True, "max_new_tokens": 256})
# Finally: parse and apply `generate_flags` # Finally: parse and apply `generate_flags`
parsed_generate_flags = self.parse_generate_flags(args.generate_flags) parsed_generate_flags = self.parse_generate_flags(args.generate_flags)

File diff suppressed because it is too large Load Diff

View File

@@ -43,6 +43,7 @@ deps = {
"onnxconverter-common": "onnxconverter-common", "onnxconverter-common": "onnxconverter-common",
"onnxruntime-tools": "onnxruntime-tools>=1.4.2", "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
"onnxruntime": "onnxruntime>=1.4.0", "onnxruntime": "onnxruntime>=1.4.0",
"openai": "openai",
"opencv-python": "opencv-python", "opencv-python": "opencv-python",
"optimum-benchmark": "optimum-benchmark>=0.3.0", "optimum-benchmark": "optimum-benchmark>=0.3.0",
"optuna": "optuna", "optuna": "optuna",

View File

@@ -112,6 +112,7 @@ from .utils import (
is_natten_available, is_natten_available,
is_nltk_available, is_nltk_available,
is_onnx_available, is_onnx_available,
is_openai_available,
is_optimum_available, is_optimum_available,
is_optimum_quanto_available, is_optimum_quanto_available,
is_pandas_available, is_pandas_available,
@@ -1536,6 +1537,13 @@ def require_speech(test_case):
return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case) return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
def require_openai(test_case):
"""
Decorator marking a test that requires openai
"""
return unittest.skipUnless(is_openai_available(), "test requires openai")(test_case)
def require_mistral_common(test_case): def require_mistral_common(test_case):
""" """
Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available. Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available.

View File

@@ -515,6 +515,10 @@ def is_uvicorn_available():
return _uvicorn_available return _uvicorn_available
def is_openai_available():
return _openai_available
def is_pretty_midi_available(): def is_pretty_midi_available():
return _pretty_midi_available return _pretty_midi_available
@@ -730,10 +734,6 @@ def is_onnx_available():
return _onnx_available return _onnx_available
def is_openai_available():
return _openai_available
def is_flax_available(): def is_flax_available():
return _flax_available return _flax_available
@@ -1916,6 +1916,12 @@ UVICORN_IMPORT_ERROR = """
`pip install uvicorn`. Please note that you may need to restart your runtime after installation. `pip install uvicorn`. Please note that you may need to restart your runtime after installation.
""" """
# docstyle-ignore
OPENAI_IMPORT_ERROR = """
{0} requires the openai library but it was not found in your environment. You can install it with pip:
`pip install openai`. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore # docstyle-ignore
PYTESSERACT_IMPORT_ERROR = """ PYTESSERACT_IMPORT_ERROR = """
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip: {0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
@@ -2046,6 +2052,7 @@ BACKENDS_MAPPING = OrderedDict(
("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)), ("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)), ("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)), ("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
("openai", (is_openai_available, OPENAI_IMPORT_ERROR)),
("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)), ("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)),
] ]
) )

View File

@@ -24,9 +24,16 @@ from parameterized import parameterized
import transformers.commands.transformers_cli as cli import transformers.commands.transformers_cli as cli
from transformers import GenerationConfig from transformers import GenerationConfig
from transformers.commands.serving import ServeArguments, ServeCommand from transformers.commands.serving import ServeArguments, ServeCommand
from transformers.testing_utils import CaptureStd, slow from transformers.testing_utils import CaptureStd, require_openai, slow
from transformers.utils.import_utils import is_openai_available
if is_openai_available():
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
from openai.types.responses import Response, ResponseCreatedEvent
@require_openai
class ServeCLITest(unittest.TestCase): class ServeCLITest(unittest.TestCase):
def test_help(self): def test_help(self):
"""Minimal test: we can invoke the help command.""" """Minimal test: we can invoke the help command."""
@@ -49,36 +56,94 @@ class ServeCLITest(unittest.TestCase):
self.assertEqual(parsed_args.host, "0.0.0.0") self.assertEqual(parsed_args.host, "0.0.0.0")
self.assertEqual(parsed_args.port, 9000) self.assertEqual(parsed_args.port, 9000)
def test_completions_build_chunk(self): def test_build_chat_completion_chunk(self):
"""Tests that the chunks are correctly built for the Completions API.""" """
Tests that the chunks are correctly built for the Chat Completion API. The `choices` checks implictly
confirm that empty fields are not emitted.
"""
dummy = ServeCommand.__new__(ServeCommand) dummy = ServeCommand.__new__(ServeCommand)
dummy.args = type("Args", (), {})() dummy.args = type("Args", (), {})()
dummy.loaded_model = "dummy_model@main"
# The keys for these fields must be present in every chunk
MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"]
# Case 1: most fields are provided # Case 1: most fields are provided
chunk = ServeCommand.build_chunk(dummy, request_id="req0", content="hello", finish_reason="stop", role="user") chunk = ServeCommand.build_chat_completion_chunk(
self.assertIn("chat.completion.chunk", chunk) dummy, request_id="req0", content="hello", finish_reason="stop", role="user"
self.assertIn("data:", chunk) )
for field in MANDATORY_FIELDS:
self.assertIn(field, chunk)
self.assertIn( self.assertIn(
'"choices": [{"delta": {"content": "hello", "role": "user"}, "index": 0, "finish_reason": "stop"}]', chunk '"choices":[{"delta":{"content":"hello","role":"user"},"finish_reason":"stop","index":0}]', chunk
) )
# Case 2: only the role is provided -- other fields in 'choices' are omitted # Case 2: only the role is provided -- other fields in 'choices' are omitted
chunk = ServeCommand.build_chunk(dummy, request_id="req0", role="user") chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user")
self.assertIn("chat.completion.chunk", chunk) for field in MANDATORY_FIELDS:
self.assertIn("data:", chunk) self.assertIn(field, chunk)
self.assertIn('"choices":[{"delta":{"role":"user"},"index":0}]', chunk) self.assertIn('"choices":[{"delta":{"role":"user"},"index":0}]', chunk)
# Case 3: only the content is provided -- other fields in 'choices' are omitted # Case 3: only the content is provided -- other fields in 'choices' are omitted
chunk = ServeCommand.build_chunk(dummy, request_id="req0", content="hello") chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello")
self.assertIn("chat.completion.chunk", chunk) for field in MANDATORY_FIELDS:
self.assertIn("data:", chunk) self.assertIn(field, chunk)
self.assertIn('"choices":[{"delta":{"content":"hello"},"index":0}]', chunk) self.assertIn('"choices":[{"delta":{"content":"hello"},"index":0}]', chunk)
# Case 4: tool calls support a list of nested dictionaries # Case 4: tool calls support a list of ChoiceDeltaToolCall objects
chunk = ServeCommand.build_chunk(dummy, request_id="req0", tool_calls=[{"foo1": "bar1", "foo2": "bar2"}]) tool_call = ChoiceDeltaToolCall(
self.assertIn("chat.completion.chunk", chunk) index=0,
self.assertIn("data:", chunk) function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'),
self.assertIn('"choices": [{"delta": {"tool_calls": [{"foo1": "bar1", "foo2": "bar2"}]}, "index": 0}]', chunk) type="function",
)
chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call])
for field in MANDATORY_FIELDS:
self.assertIn(field, chunk)
expected_choices_content = (
'choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"foo1\\": \\"bar1\\", '
'\\"foo2\\": \\"bar2\\"}","name":"foo_bar"},"type":"function"}]},"index":0}]'
)
self.assertIn(expected_choices_content, chunk)
def test_build_response_event(self):
"""
Tests that the events are correctly built for the Response API.
Contrarily to the Chat Completion API, the Response API has a wide set of possible output objects. This test
only checks a few basic assumptions -- we rely on OpenAI's pydantic models to enforce the correct schema.
"""
dummy = ServeCommand.__new__(ServeCommand)
dummy.args = type("Args", (), {})()
response_created = ResponseCreatedEvent(
type="response.created",
sequence_number=0,
response=Response(
id="resp_0",
created_at=time.time(),
status="queued",
model="dummy_model@main",
instructions=None, # <--- is set to None = should NOT be in the output.
text={"format": {"type": "text"}},
object="response",
tools=[], # <--- empty lists should be in the output (they are often mandatory fields)
output=[],
parallel_tool_calls=False,
tool_choice="auto",
metadata=None,
),
)
event = dummy.build_response_event(response_created)
self.assertTrue(event.startswith("data: ")) # Sanity check: event formatting
self.assertIn('"model":"dummy_model@main"', event) # Sanity check: set field
self.assertIn('"status":"queued"', event)
self.assertIn("tools", event) # empty lists should be in the output
self.assertIn("output", event)
self.assertNotIn("instructions", event) # None fields should NOT be in the output
self.assertNotIn("metadata", event)
self.assertNotIn("error", event) # Unset optional fields should NOT be in the output
self.assertNotIn("top_p", event)
def async_retry(fn, max_attempts=5, delay=2): def async_retry(fn, max_attempts=5, delay=2):
@@ -105,7 +170,7 @@ class ServeCompletionsMixin:
@async_retry @async_retry
async def run_server(self, request): async def run_server(self, request):
client = AsyncInferenceClient("http://localhost:8000") client = AsyncInferenceClient(f"http://localhost:{self.port}")
stream = client.chat_completion(**request) stream = client.chat_completion(**request)
all_payloads = [] all_payloads = []
@@ -119,8 +184,7 @@ class ServeCompletionsMixin:
[ [
("default_request", {}), ("default_request", {}),
("one_token", {"max_tokens": 1}), ("one_token", {"max_tokens": 1}),
# TODO: CB fails next case, seems like it is unable to switch models. fix me ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}),
# ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}),
( (
"tool_call", "tool_call",
{ {
@@ -191,20 +255,20 @@ class ServeCompletionsMixin:
# sets `do_sample=True` # sets `do_sample=True`
self.assertEqual(output_text, '<think>\nOkay, the user just asked, "') self.assertEqual(output_text, '<think>\nOkay, the user just asked, "')
# TODO: implement API-compliant error handling, and then test it
# See https://platform.openai.com/docs/guides/error-codes,
# TODO: one test for each request flag, to confirm it is working as expected # TODO: one test for each request flag, to confirm it is working as expected
# TODO: speed-based test to confirm that KV cache is working across requests # TODO: speed-based test to confirm that KV cache is working across requests
@slow # TODO (joao): this shouldn't be needed @slow # server startup time is slow on our push CI
class ServeCompletionsGenerateTest(ServeCompletionsMixin, unittest.TestCase): @require_openai
class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
"""Tests the `generate` version of the Completions API.""" """Tests the `generate` version of the Completions API."""
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""Starts a server for tests to connect to.""" """Starts a server for tests to connect to."""
args = ServeArguments() cls.port = 8001
args = ServeArguments(port=cls.port)
serve_command = ServeCommand(args) serve_command = ServeCommand(args)
thread = Thread(target=serve_command.run) thread = Thread(target=serve_command.run)
thread.daemon = True thread.daemon = True
@@ -287,15 +351,20 @@ class ServeCompletionsGenerateTest(ServeCompletionsMixin, unittest.TestCase):
self.assertTrue(all(reason is None for reason in finish_reasons[:-1])) self.assertTrue(all(reason is None for reason in finish_reasons[:-1]))
@slow # TODO (joao): this shouldn't be needed @slow # server startup time is slow on our push CI
class ServeCompletionsContinuousBatchingTest(ServeCompletionsMixin, unittest.TestCase): @require_openai
class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
"""Tests the `continuous_batching` version of the Completions API.""" """Tests the `continuous_batching` version of the Completions API."""
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""Starts a server for tests to connect to.""" """Starts a server for tests to connect to."""
args = ServeArguments(attn_implementation="sdpa_paged") # important: toggle continuous batching cls.port = 8002
args = ServeArguments(port=cls.port, attn_implementation="sdpa_paged") # important: toggle continuous batching
serve_command = ServeCommand(args) serve_command = ServeCommand(args)
thread = Thread(target=serve_command.run) thread = Thread(target=serve_command.run)
thread.daemon = True thread.daemon = True
thread.start() thread.start()
# TODO: Response integration tests