Responses API in transformers serve (#39155)
* Scaffolding * Explicit content * Naïve Responses API streaming implementation * Cleanup * Responses API (to be merged into #39155) (#39338) * Scaffolding * Explicit content * Naïve Responses API streaming implementation * Cleanup * use openai * validate request, including detecting unused fields * dict indexing * dict var access * tmp commit (tests failing) * add slow * use oai output type in completions * (little rebase errors) * working spec? * guard type hint * type hints. fix state (CB can now load different models) * type hints; fn names; error type * add docstrings * responses + kv cache * metadata support; fix kv cache; error event * add output_index and content_index * docstrings * add test_build_response_event * docs/comments * gate test requirements; terminate cb manager on model switch * nasty type hints * more type hints * disable validation by default; enable force models * todo --------- Co-authored-by: Lysandre <hi@lysand.re> * Slight bugfixes * PR comments from #39338 * make fixup --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -24,9 +24,16 @@ from parameterized import parameterized
|
||||
import transformers.commands.transformers_cli as cli
|
||||
from transformers import GenerationConfig
|
||||
from transformers.commands.serving import ServeArguments, ServeCommand
|
||||
from transformers.testing_utils import CaptureStd, slow
|
||||
from transformers.testing_utils import CaptureStd, require_openai, slow
|
||||
from transformers.utils.import_utils import is_openai_available
|
||||
|
||||
|
||||
if is_openai_available():
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
|
||||
from openai.types.responses import Response, ResponseCreatedEvent
|
||||
|
||||
|
||||
@require_openai
|
||||
class ServeCLITest(unittest.TestCase):
|
||||
def test_help(self):
|
||||
"""Minimal test: we can invoke the help command."""
|
||||
@@ -49,36 +56,94 @@ class ServeCLITest(unittest.TestCase):
|
||||
self.assertEqual(parsed_args.host, "0.0.0.0")
|
||||
self.assertEqual(parsed_args.port, 9000)
|
||||
|
||||
def test_completions_build_chunk(self):
|
||||
"""Tests that the chunks are correctly built for the Completions API."""
|
||||
def test_build_chat_completion_chunk(self):
|
||||
"""
|
||||
Tests that the chunks are correctly built for the Chat Completion API. The `choices` checks implictly
|
||||
confirm that empty fields are not emitted.
|
||||
"""
|
||||
dummy = ServeCommand.__new__(ServeCommand)
|
||||
dummy.args = type("Args", (), {})()
|
||||
dummy.loaded_model = "dummy_model@main"
|
||||
|
||||
# The keys for these fields must be present in every chunk
|
||||
MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"]
|
||||
|
||||
# Case 1: most fields are provided
|
||||
chunk = ServeCommand.build_chunk(dummy, request_id="req0", content="hello", finish_reason="stop", role="user")
|
||||
self.assertIn("chat.completion.chunk", chunk)
|
||||
self.assertIn("data:", chunk)
|
||||
chunk = ServeCommand.build_chat_completion_chunk(
|
||||
dummy, request_id="req0", content="hello", finish_reason="stop", role="user"
|
||||
)
|
||||
for field in MANDATORY_FIELDS:
|
||||
self.assertIn(field, chunk)
|
||||
self.assertIn(
|
||||
'"choices": [{"delta": {"content": "hello", "role": "user"}, "index": 0, "finish_reason": "stop"}]', chunk
|
||||
'"choices":[{"delta":{"content":"hello","role":"user"},"finish_reason":"stop","index":0}]', chunk
|
||||
)
|
||||
|
||||
# Case 2: only the role is provided -- other fields in 'choices' are omitted
|
||||
chunk = ServeCommand.build_chunk(dummy, request_id="req0", role="user")
|
||||
self.assertIn("chat.completion.chunk", chunk)
|
||||
self.assertIn("data:", chunk)
|
||||
self.assertIn('"choices": [{"delta": {"role": "user"}, "index": 0}]', chunk)
|
||||
chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user")
|
||||
for field in MANDATORY_FIELDS:
|
||||
self.assertIn(field, chunk)
|
||||
self.assertIn('"choices":[{"delta":{"role":"user"},"index":0}]', chunk)
|
||||
|
||||
# Case 3: only the content is provided -- other fields in 'choices' are omitted
|
||||
chunk = ServeCommand.build_chunk(dummy, request_id="req0", content="hello")
|
||||
self.assertIn("chat.completion.chunk", chunk)
|
||||
self.assertIn("data:", chunk)
|
||||
self.assertIn('"choices": [{"delta": {"content": "hello"}, "index": 0}]', chunk)
|
||||
chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello")
|
||||
for field in MANDATORY_FIELDS:
|
||||
self.assertIn(field, chunk)
|
||||
self.assertIn('"choices":[{"delta":{"content":"hello"},"index":0}]', chunk)
|
||||
|
||||
# Case 4: tool calls support a list of nested dictionaries
|
||||
chunk = ServeCommand.build_chunk(dummy, request_id="req0", tool_calls=[{"foo1": "bar1", "foo2": "bar2"}])
|
||||
self.assertIn("chat.completion.chunk", chunk)
|
||||
self.assertIn("data:", chunk)
|
||||
self.assertIn('"choices": [{"delta": {"tool_calls": [{"foo1": "bar1", "foo2": "bar2"}]}, "index": 0}]', chunk)
|
||||
# Case 4: tool calls support a list of ChoiceDeltaToolCall objects
|
||||
tool_call = ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'),
|
||||
type="function",
|
||||
)
|
||||
chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call])
|
||||
for field in MANDATORY_FIELDS:
|
||||
self.assertIn(field, chunk)
|
||||
expected_choices_content = (
|
||||
'choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"foo1\\": \\"bar1\\", '
|
||||
'\\"foo2\\": \\"bar2\\"}","name":"foo_bar"},"type":"function"}]},"index":0}]'
|
||||
)
|
||||
self.assertIn(expected_choices_content, chunk)
|
||||
|
||||
def test_build_response_event(self):
|
||||
"""
|
||||
Tests that the events are correctly built for the Response API.
|
||||
|
||||
Contrarily to the Chat Completion API, the Response API has a wide set of possible output objects. This test
|
||||
only checks a few basic assumptions -- we rely on OpenAI's pydantic models to enforce the correct schema.
|
||||
"""
|
||||
dummy = ServeCommand.__new__(ServeCommand)
|
||||
dummy.args = type("Args", (), {})()
|
||||
|
||||
response_created = ResponseCreatedEvent(
|
||||
type="response.created",
|
||||
sequence_number=0,
|
||||
response=Response(
|
||||
id="resp_0",
|
||||
created_at=time.time(),
|
||||
status="queued",
|
||||
model="dummy_model@main",
|
||||
instructions=None, # <--- is set to None = should NOT be in the output.
|
||||
text={"format": {"type": "text"}},
|
||||
object="response",
|
||||
tools=[], # <--- empty lists should be in the output (they are often mandatory fields)
|
||||
output=[],
|
||||
parallel_tool_calls=False,
|
||||
tool_choice="auto",
|
||||
metadata=None,
|
||||
),
|
||||
)
|
||||
|
||||
event = dummy.build_response_event(response_created)
|
||||
self.assertTrue(event.startswith("data: ")) # Sanity check: event formatting
|
||||
self.assertIn('"model":"dummy_model@main"', event) # Sanity check: set field
|
||||
self.assertIn('"status":"queued"', event)
|
||||
self.assertIn("tools", event) # empty lists should be in the output
|
||||
self.assertIn("output", event)
|
||||
self.assertNotIn("instructions", event) # None fields should NOT be in the output
|
||||
self.assertNotIn("metadata", event)
|
||||
self.assertNotIn("error", event) # Unset optional fields should NOT be in the output
|
||||
self.assertNotIn("top_p", event)
|
||||
|
||||
|
||||
def async_retry(fn, max_attempts=5, delay=2):
|
||||
@@ -105,7 +170,7 @@ class ServeCompletionsMixin:
|
||||
|
||||
@async_retry
|
||||
async def run_server(self, request):
|
||||
client = AsyncInferenceClient("http://localhost:8000")
|
||||
client = AsyncInferenceClient(f"http://localhost:{self.port}")
|
||||
stream = client.chat_completion(**request)
|
||||
|
||||
all_payloads = []
|
||||
@@ -119,8 +184,7 @@ class ServeCompletionsMixin:
|
||||
[
|
||||
("default_request", {}),
|
||||
("one_token", {"max_tokens": 1}),
|
||||
# TODO: CB fails next case, seems like it is unable to switch models. fix me
|
||||
# ("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}),
|
||||
("different_model", {"model": "HuggingFaceTB/SmolLM2-135M-Instruct"}),
|
||||
(
|
||||
"tool_call",
|
||||
{
|
||||
@@ -191,20 +255,20 @@ class ServeCompletionsMixin:
|
||||
# sets `do_sample=True`
|
||||
self.assertEqual(output_text, '<think>\nOkay, the user just asked, "')
|
||||
|
||||
# TODO: implement API-compliant error handling, and then test it
|
||||
# See https://platform.openai.com/docs/guides/error-codes,
|
||||
# TODO: one test for each request flag, to confirm it is working as expected
|
||||
# TODO: speed-based test to confirm that KV cache is working across requests
|
||||
|
||||
|
||||
@slow # TODO (joao): this shouldn't be needed
|
||||
class ServeCompletionsGenerateTest(ServeCompletionsMixin, unittest.TestCase):
|
||||
@slow # server startup time is slow on our push CI
|
||||
@require_openai
|
||||
class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
|
||||
"""Tests the `generate` version of the Completions API."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Starts a server for tests to connect to."""
|
||||
args = ServeArguments()
|
||||
cls.port = 8001
|
||||
args = ServeArguments(port=cls.port)
|
||||
serve_command = ServeCommand(args)
|
||||
thread = Thread(target=serve_command.run)
|
||||
thread.daemon = True
|
||||
@@ -287,15 +351,20 @@ class ServeCompletionsGenerateTest(ServeCompletionsMixin, unittest.TestCase):
|
||||
self.assertTrue(all(reason is None for reason in finish_reasons[:-1]))
|
||||
|
||||
|
||||
@slow # TODO (joao): this shouldn't be needed
|
||||
class ServeCompletionsContinuousBatchingTest(ServeCompletionsMixin, unittest.TestCase):
|
||||
@slow # server startup time is slow on our push CI
|
||||
@require_openai
|
||||
class ServeCompletionsContinuousBatchingIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
|
||||
"""Tests the `continuous_batching` version of the Completions API."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""Starts a server for tests to connect to."""
|
||||
args = ServeArguments(attn_implementation="sdpa_paged") # important: toggle continuous batching
|
||||
cls.port = 8002
|
||||
args = ServeArguments(port=cls.port, attn_implementation="sdpa_paged") # important: toggle continuous batching
|
||||
serve_command = ServeCommand(args)
|
||||
thread = Thread(target=serve_command.run)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
|
||||
# TODO: Response integration tests
|
||||
|
||||
Reference in New Issue
Block a user