Responses API in transformers serve (#39155)
* Scaffolding * Explicit content * Naïve Responses API streaming implementation * Cleanup * Responses API (to be merged into #39155) (#39338) * Scaffolding * Explicit content * Naïve Responses API streaming implementation * Cleanup * use openai * validate request, including detecting unused fields * dict indexing * dict var access * tmp commit (tests failing) * add slow * use oai output type in completions * (little rebase errors) * working spec? * guard type hint * type hints. fix state (CB can now load different models) * type hints; fn names; error type * add docstrings * responses + kv cache * metadata support; fix kv cache; error event * add output_index and content_index * docstrings * add test_build_response_event * docs/comments * gate test requirements; terminate cb manager on model switch * nasty type hints * more type hints * disable validation by default; enable force models * todo --------- Co-authored-by: Lysandre <hi@lysand.re> * Slight bugfixes * PR comments from #39338 * make fixup --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -71,7 +71,7 @@ vllm serve Qwen/Qwen2.5-1.5B-Instruct \
|
|||||||
> This section is experimental and subject to change in future versions
|
> This section is experimental and subject to change in future versions
|
||||||
|
|
||||||
<!-- TODO: LLMs -> models, after we add audio/image input/output support -->
|
<!-- TODO: LLMs -> models, after we add audio/image input/output support -->
|
||||||
You can serve LLMs supported by `transformers` with the `transformers serve` CLI. It spawns a local server that offers a chat Completions API compatible with the OpenAI SDK, which is the _de facto_ standard for LLM conversations. This way, you can use the server from many third party applications, or test it using the `transformers chat` CLI ([docs](conversations.md#chat-cli)).
|
You can serve LLMs supported by `transformers` with the `transformers serve` CLI. It spawns a local server that offers a Chat Completion API and a Response API compatible with the OpenAI SDK, which are the _de facto_ standards for LLM conversations. This way, you can use the server from many third-party applications, or test it using the `transformers chat` CLI ([docs](conversations.md#chat-cli)).
|
||||||
|
|
||||||
To launch a server, simply use the `transformers serve` CLI command:
|
To launch a server, simply use the `transformers serve` CLI command:
|
||||||
|
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -137,6 +137,7 @@ _deps = [
|
|||||||
"onnxconverter-common",
|
"onnxconverter-common",
|
||||||
"onnxruntime-tools>=1.4.2",
|
"onnxruntime-tools>=1.4.2",
|
||||||
"onnxruntime>=1.4.0",
|
"onnxruntime>=1.4.0",
|
||||||
|
"openai",
|
||||||
"opencv-python",
|
"opencv-python",
|
||||||
"optimum-benchmark>=0.3.0",
|
"optimum-benchmark>=0.3.0",
|
||||||
"optuna",
|
"optuna",
|
||||||
@@ -314,7 +315,7 @@ extras["hub-kernels"] = deps_list("kernels")
|
|||||||
|
|
||||||
extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]
|
extras["integrations"] = extras["hub-kernels"] + extras["optuna"] + extras["ray"] + extras["sigopt"]
|
||||||
|
|
||||||
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette") + extras["torch"]
|
extras["serving"] = deps_list("openai", "pydantic", "uvicorn", "fastapi", "starlette") + extras["torch"]
|
||||||
extras["audio"] = deps_list(
|
extras["audio"] = deps_list(
|
||||||
"librosa",
|
"librosa",
|
||||||
"pyctcdecode",
|
"pyctcdecode",
|
||||||
|
|||||||
@@ -471,7 +471,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
|||||||
# This is a chat session, so we have a few non-standard defaults
|
# This is a chat session, so we have a few non-standard defaults
|
||||||
# !!!!!!!!!
|
# !!!!!!!!!
|
||||||
generation_config = copy.deepcopy(model_generation_config)
|
generation_config = copy.deepcopy(model_generation_config)
|
||||||
generation_config.update({"do_sample": True, "max_new_tokens": 256})
|
generation_config.update(**{"do_sample": True, "max_new_tokens": 256})
|
||||||
|
|
||||||
# Finally: parse and apply `generate_flags`
|
# Finally: parse and apply `generate_flags`
|
||||||
parsed_generate_flags = self.parse_generate_flags(args.generate_flags)
|
parsed_generate_flags = self.parse_generate_flags(args.generate_flags)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -43,6 +43,7 @@ deps = {
|
|||||||
"onnxconverter-common": "onnxconverter-common",
|
"onnxconverter-common": "onnxconverter-common",
|
||||||
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
||||||
"onnxruntime": "onnxruntime>=1.4.0",
|
"onnxruntime": "onnxruntime>=1.4.0",
|
||||||
|
"openai": "openai",
|
||||||
"opencv-python": "opencv-python",
|
"opencv-python": "opencv-python",
|
||||||
"optimum-benchmark": "optimum-benchmark>=0.3.0",
|
"optimum-benchmark": "optimum-benchmark>=0.3.0",
|
||||||
"optuna": "optuna",
|
"optuna": "optuna",
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ from .utils import (
|
|||||||
is_natten_available,
|
is_natten_available,
|
||||||
is_nltk_available,
|
is_nltk_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
|
is_openai_available,
|
||||||
is_optimum_available,
|
is_optimum_available,
|
||||||
is_optimum_quanto_available,
|
is_optimum_quanto_available,
|
||||||
is_pandas_available,
|
is_pandas_available,
|
||||||
@@ -1536,6 +1537,13 @@ def require_speech(test_case):
|
|||||||
return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
|
return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_openai(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires openai
|
||||||
|
"""
|
||||||
|
return unittest.skipUnless(is_openai_available(), "test requires openai")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_mistral_common(test_case):
|
def require_mistral_common(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available.
|
Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available.
|
||||||
|
|||||||
@@ -515,6 +515,10 @@ def is_uvicorn_available():
|
|||||||
return _uvicorn_available
|
return _uvicorn_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_openai_available():
|
||||||
|
return _openai_available
|
||||||
|
|
||||||
|
|
||||||
def is_pretty_midi_available():
|
def is_pretty_midi_available():
|
||||||
return _pretty_midi_available
|
return _pretty_midi_available
|
||||||
|
|
||||||
@@ -730,10 +734,6 @@ def is_onnx_available():
|
|||||||
return _onnx_available
|
return _onnx_available
|
||||||
|
|
||||||
|
|
||||||
def is_openai_available():
|
|
||||||
return _openai_available
|
|
||||||
|
|
||||||
|
|
||||||
def is_flax_available():
|
def is_flax_available():
|
||||||
return _flax_available
|
return _flax_available
|
||||||
|
|
||||||
@@ -1916,6 +1916,12 @@ UVICORN_IMPORT_ERROR = """
|
|||||||
`pip install uvicorn`. Please note that you may need to restart your runtime after installation.
|
`pip install uvicorn`. Please note that you may need to restart your runtime after installation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# docstyle-ignore
|
||||||
|
OPENAI_IMPORT_ERROR = """
|
||||||
|
{0} requires the openai library but it was not found in your environment. You can install it with pip:
|
||||||
|
`pip install openai`. Please note that you may need to restart your runtime after installation.
|
||||||
|
"""
|
||||||
|
|
||||||
# docstyle-ignore
|
# docstyle-ignore
|
||||||
PYTESSERACT_IMPORT_ERROR = """
|
PYTESSERACT_IMPORT_ERROR = """
|
||||||
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
|
{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip:
|
||||||
@@ -2046,6 +2052,7 @@ BACKENDS_MAPPING = OrderedDict(
|
|||||||
("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
|
("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
|
||||||
("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
|
("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
|
||||||
("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
|
("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
|
||||||
|
("openai", (is_openai_available, OPENAI_IMPORT_ERROR)),
|
||||||
("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)),
|
("mistral-common", (is_mistral_common_available, MISTRAL_COMMON_IMPORT_ERROR)),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,9 +24,16 @@ from parameterized import parameterized
|
|||||||
import transformers.commands.transformers_cli as cli
|
import transformers.commands.transformers_cli as cli
|
||||||
from transformers import GenerationConfig
|
from transformers import GenerationConfig
|
||||||
from transformers.commands.serving import ServeArguments, ServeCommand
|
from transformers.commands.serving import ServeArguments, ServeCommand
|
||||||
from transformers.testing_utils import CaptureStd, slow
|
from transformers.testing_utils import CaptureStd, require_openai, slow
|
||||||
|
from transformers.utils.import_utils import is_openai_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_openai_available():
|
||||||
|
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
|
||||||
|
from openai.types.responses import Response, ResponseCreatedEvent
|
||||||
|
|
||||||
|
|
||||||
|
@require_openai
|
||||||
class ServeCLITest(unittest.TestCase):
|
class ServeCLITest(unittest.TestCase):
|
||||||
def test_help(self):
|
def test_help(self):
|
||||||
"""Minimal test: we can invoke the help command."""
|
"""Minimal test: we can invoke the help command."""
|
||||||
@@ -49,36 +56,94 @@ class ServeCLITest(unittest.TestCase):
|
|||||||
self.assertEqual(parsed_args.host, "0.0.0.0")
|
self.assertEqual(parsed_args.host, "0.0.0.0")
|
||||||
self.assertEqual(parsed_args.port, 9000)
|
self.assertEqual(parsed_args.port, 9000)
|
||||||
|
|
||||||
def test_completions_build_chunk(self):
|
def test_build_chat_completion_chunk(self):
|
||||||
"""Tests that the chunks are correctly built for the Completions API."""
|
"""
|
||||||
|
Tests that the chunks are correctly built for the Chat Completion API. The `choices` checks implicitly
|
||||||
|
confirm that empty fields are not emitted.
|
||||||
|
"""
|
||||||
dummy = ServeCommand.__new__(ServeCommand)
|
dummy = ServeCommand.__new__(ServeCommand)
|
||||||
dummy.args = type("Args", (), {})()
|
dummy.args = type("Args", (), {})()
|
||||||
|
dummy.loaded_model = "dummy_model@main"
|
||||||
|
|
||||||
|
# The keys for these fields must be present in every chunk
|
||||||
|
MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"]
|
||||||
|
|
||||||
# Case 1: most fields are provided
|
# Case 1: most fields are provided
|
||||||
chunk = ServeCommand.build_chunk(dummy, request_id="req0", content="hello", finish_reason="stop", role="user")
|
chunk = ServeCommand.build_chat_completion_chunk(
|
||||||
self.assertIn("chat.completion.chunk", chunk)
|
dummy, request_id="req0", content="hello", finish_reason="stop", role="user"
|
||||||
self.assertIn("data:", chunk)
|
)
|
||||||
|
for field in MANDATORY_FIELDS:
|
||||||
|
self.assertIn(field, chunk)
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
'"choices": [{"delta": {"content": "hello", "role": "user"}, "index": 0, "finish_reason": "stop"}]', chunk
|
'"choices":[{"delta":{"content":"hello","role":"user"},"finish_reason":"stop","index":0}]', chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
# Case 2: only the role is provided -- other fields in 'choices' are omitted
|
# Case 2: only the role is provided -- other fields in 'choices' are omitted
|
||||||
chunk = ServeCommand.build_chunk(dummy, request_id="req0", role="user")
|
chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user")
|
||||||
self.assertIn("chat.completion.chunk", chunk)
|
for field in MANDATORY_FIELDS:
|
||||||
self.assertIn("data:", chunk)
|
self.assertIn(field, chunk)
|
||||||
self.assertIn('"choices": [{"delta": {"role": "user"}, "index": 0}]', chunk)
|
self.assertIn('"choices":[{"delta":{"role":"user"},"index":0}]', chunk)
|
||||||
|
|
||||||
# Case 3: only the content is provided -- other fields in 'choices' are omitted
|
# Case 3: only the content is provided -- other fields in 'choices' are omitted
|
||||||
chunk = ServeCommand.build_chunk(dummy, request_id="req0", content="hello")
|
chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello")
|
||||||
self.assertIn("chat.completion.chunk", chunk)
|
for field in MANDATORY_FIELDS:
|
||||||
self.assertIn("data:", chunk)
|
self.assertIn(field, chunk)
|
||||||
self.assertIn('"choices": [{"delta": {"content": "hello"}, "index": 0}]', chunk)
|
self.assertIn('"choices":[{"delta":{"content":"hello"},"index":0}]', chunk)
|
||||||
|
|
||||||
# Case 4: tool calls support a list of nested dictionaries
|
# Case 4: tool calls support a list of ChoiceDeltaToolCall objects
|
||||||
chunk = ServeCommand.build_chunk(dummy, request_id="req0", tool_calls=[{"foo1": "bar1", "foo2": "bar2"}])
|
tool_call = ChoiceDeltaToolCall(
|
||||||
self.assertIn("chat.completion.chunk", chunk)
|
index=0,
|
||||||
self.assertIn("data:", chunk)
|
function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'),
|
||||||
self.assertIn('"choices": [{"delta": {"tool_calls": [{"foo1": "bar1", "foo2": "bar2"}]}, "index": 0}]', chunk)
|
type="function",
|
||||||
|
)
|
||||||
|
chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call])
|
||||||
|
for field in MANDATORY_FIELDS:
|
||||||
|
self.assertIn(field, chunk)
|
||||||
|
expected_choices_content = (
|
||||||
|
'choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"foo1\\": \\"bar1\\", '
|
||||||
|
'\\"foo2\\": \\"bar2\\"}","name":"foo_bar"},"type":"function"}]},"index":0}]'
|
||||||
|
)
|
||||||
|
self.assertIn(expected_choices_content, chunk)
|
||||||
|
|
||||||
|
def test_build_response_event(self):
|
||||||
|
"""
|
||||||
|
Tests that the events are correctly built for the Response API.
|
||||||
|
|
||||||
|
Unlike the Chat Completion API, the Response API has a wide set of possible output objects. This test
|
||||||
|
only checks a few basic assumptions -- we rely on OpenAI's pydantic models to enforce the correct schema.
|
||||||
|
"""
|
||||||
|
dummy = ServeCommand.__new__(ServeCommand)
|
||||||
|
dummy.args = type("Args", (), {})()
|
||||||
|
|
||||||
|
response_created = ResponseCreatedEvent(
|
||||||
|
type="response.created",
|
||||||
|
sequence_number=0,
|
||||||
|
response=Response(
|
||||||
|
id="resp_0",
|
||||||
|
created_at=time.time(),
|
||||||
|
status="queued",
|
||||||
|
model="dummy_model@main",
|
||||||
|
instructions=None, # <--- is set to None = should NOT be in the output.
|
||||||
|
text={"format": {"type": "text"}},
|
||||||
|
object="response",
|
||||||
|
tools=[], # <--- empty lists should be in the output (they are often mandatory fields)
|
||||||
|
output=[],
|
||||||
|
parallel_tool_calls=False,
|
||||||
|
tool_choice="auto",
|
||||||
|
metadata=None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
event = dummy.build_response_event(response_created)
|
||||||
|
self.assertTrue(event.startswith("data: ")) # Sanity check: event formatting
|
||||||
|
self.assertIn('"model":"dummy_model@main"', event) # Sanity check: set field
|
||||||
|
self.assertIn('"status":"queued"', event)
|
||||||
|
self.assertIn("tools", event) # empty lists should be in the output
|
||||||
|
self.assertIn("output", event)
|
||||||
|
self.assertNotIn("instructions", event) # None fields should NOT be in the output
|
||||||
|
self.assertNotIn("metadata", event)
|
||||||
|
self.assertNotIn("error", event) # Unset optional fields should NOT be in the output
|
||||||
|
self.assertNotIn("top_p", event)
|
||||||
|
|
||||||
|
|
||||||
def async_retry(fn, max_attempts=5, delay=2):
|
def async_retry(fn, max_attempts=5, delay=2):
|
||||||
@@ -105,7 +170,7 @@ class ServeCompletionsMixin:
|
|||||||
|
|
||||||
@async_retry
|
@async_retry
|
||||||
async def run_server(self, request):
|
async def run_server(self, request):
|
||||||
client = AsyncInferenceClient("http://localhost:8000")
|
client = AsyncInferenceClient(f"http://localhost:{self.port}")
|
||||||
stream = client.chat_completion(**request)
|
stream = client.chat_completion(**request)
|
||||||
|
|
||||||
all_payloads = []
|
all_payloads = []
|
||||||
@@ -119,8 +184,7 @@ class ServeCompletionsMixin:
|
|||||||
[
|
[
|
||||||
("default_request", {}),
|
("default_request", {}),
|
||||||
("one_token", {"max_tokens": 1}),
|
("one_token", {"max_tokens": 1}),
|
||||||
# TODO: CB fails next case, seems like it is unable to switch models. fix me
|
("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}),
|
||||||
# ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}),
|
|
||||||
(
|
(
|
||||||
"tool_call",
|
"tool_call",
|
||||||
{
|
{
|
||||||
@@ -191,20 +255,20 @@ class ServeCompletionsMixin:
|
|||||||
# sets `do_sample=True`
|
# sets `do_sample=True`
|
||||||
self.assertEqual(output_text, '<think>\nOkay, the user just asked, "')
|
self.assertEqual(output_text, '<think>\nOkay, the user just asked, "')
|
||||||
|
|
||||||
# TODO: implement API-compliant error handling, and then test it
|
|
||||||
# See https://platform.openai.com/docs/guides/error-codes,
|
|
||||||
# TODO: one test for each request flag, to confirm it is working as expected
|
# TODO: one test for each request flag, to confirm it is working as expected
|
||||||
# TODO: speed-based test to confirm that KV cache is working across requests
|
# TODO: speed-based test to confirm that KV cache is working across requests
|
||||||
|
|
||||||
|
|
||||||
@slow # TODO (joao): this shouldn't be needed
|
@slow # server startup time is slow on our push CI
|
||||||
class ServeCompletionsGenerateTest(ServeCompletionsMixin, unittest.TestCase):
|
@require_openai
|
||||||
|
class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
|
||||||
"""Tests the `generate` version of the Completions API."""
|
"""Tests the `generate` version of the Completions API."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
"""Starts a server for tests to connect to."""
|
"""Starts a server for tests to connect to."""
|
||||||
args = ServeArguments()
|
cls.port = 8001
|
||||||
|
args = ServeArguments(port=cls.port)
|
||||||
serve_command = ServeCommand(args)
|
serve_command = ServeCommand(args)
|
||||||
thread = Thread(target=serve_command.run)
|
thread = Thread(target=serve_command.run)
|
||||||
thread.daemon = True
|
thread.daemon = True
|
||||||
@@ -287,15 +351,20 @@ class ServeCompletionsGenerateTest(ServeCompletionsMixin, unittest.TestCase):
|
|||||||
self.assertTrue(all(reason is None for reason in finish_reasons[:-1]))
|
self.assertTrue(all(reason is None for reason in finish_reasons[:-1]))
|
||||||
|
|
||||||
|
|
||||||
@slow # TODO (joao): this shouldn't be needed
|
@slow # server startup time is slow on our push CI
|
||||||
class ServeCompletionsContinuousBatchingTest(ServeCompletionsMixin, unittest.TestCase):
|
@require_openai
|
||||||
|
class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
|
||||||
"""Tests the `continuous_batching` version of the Completions API."""
|
"""Tests the `continuous_batching` version of the Completions API."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
"""Starts a server for tests to connect to."""
|
"""Starts a server for tests to connect to."""
|
||||||
args = ServeArguments(attn_implementation="sdpa_paged") # important: toggle continuous batching
|
cls.port = 8002
|
||||||
|
args = ServeArguments(port=cls.port, attn_implementation="sdpa_paged") # important: toggle continuous batching
|
||||||
serve_command = ServeCommand(args)
|
serve_command = ServeCommand(args)
|
||||||
thread = Thread(target=serve_command.run)
|
thread = Thread(target=serve_command.run)
|
||||||
thread.daemon = True
|
thread.daemon = True
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Response integration tests
|
||||||
|
|||||||
Reference in New Issue
Block a user