diff --git a/docs/source/en/serving.md b/docs/source/en/serving.md
index 5f73f7e136..cad4cbeb41 100644
--- a/docs/source/en/serving.md
+++ b/docs/source/en/serving.md
@@ -70,8 +70,13 @@ vllm serve Qwen/Qwen2.5-1.5B-Instruct \
> [!WARNING]
> This section is experimental and subject to change in future versions
-
-You can serve LLMs supported by `transformers` with the `transformers serve` CLI. It spawns a local server that offers a Chat Completion API or a Response API compatible with the OpenAI SDK, which are the _de facto_ standard for LLM conversations. This way, you can use the server from many third party applications, or test it using the `transformers chat` CLI ([docs](conversations.md#chat-cli)).
+You can serve models of diverse modalities supported by `transformers` with the `transformers serve` CLI. It spawns a local server compatible with the OpenAI SDK, which is the _de facto_ standard for LLM conversations and related tasks. This way, you can use the server from many third-party applications, or test it using the `transformers chat` CLI ([docs](conversations.md#chat-cli)).
+
+The server supports the following REST APIs:
+- `/v1/chat/completions`
+- `/v1/responses`
+- `/v1/audio/transcriptions`
+- `/v1/models`
To launch a server, simply use the `transformers serve` CLI command:
@@ -109,7 +114,7 @@ The server is also an MCP client, so it can interact with MCP tools in agentic u
-### Usage example 1: apps with local requests (feat. Jan)
+### Usage example 1: chat with local requests (feat. Jan)
This example shows how to use `transformers serve` as a local LLM provider for the [Jan](https://jan.ai/) app. Jan is a ChatGPT-alternative graphical interface, fully running on your machine. The requests to `transformers serve` come directly from the local app -- while this section focuses on Jan, you can extrapolate some instructions to other apps that make local requests.
@@ -139,17 +144,17 @@ ssh -N -f -L 8000:localhost:8000 your_server_account@your_server_IP -p port_to_s
Port forwarding is not Jan-specific: you can use it to connect `transformers serve` running in a different machine with an app of your choice.
-### Usage example 2: apps with external requests (feat. Cursor)
+### Usage example 2: chat with external requests (feat. Cursor)
This example shows how to use `transformers serve` as a local LLM provider for [Cursor](https://cursor.com/), the popular IDE. Unlike in the previous example, requests to `transformers serve` will come from an external IP (Cursor's server IPs), which requires some additional setup. Furthermore, some of Cursor's requests require [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS), which is disabled by default for security reasons.
-To launch our server with CORS enabled, run
+To launch a server with CORS enabled, run
```shell
transformers serve --enable-cors
```
-We'll also need to expose our server to external IPs. A potential solution is to use [`ngrok`](https://ngrok.com/), which has a permissive free tier. After setting up your `ngrok` account and authenticating on your server machine, you run
+You'll also need to expose your server to external IPs. A potential solution is to use [`ngrok`](https://ngrok.com/), which has a permissive free tier. After setting up your `ngrok` account and authenticating on your server machine, you run
```shell
ngrok http [port]
@@ -161,7 +166,7 @@ where `port` is the port used by `transformers serve` (`8000` by default). On th
-We're now ready to set things up on the app side! In Cursor, while we can't set a new provider, we can change the endpoint for OpenAI requests in the model selection settings. First, navigate to "Settings" > "Cursor Settings", "Models" tab, and expand the "API Keys" collapsible. To set our `transformers serve` endpoint, follow this order:
+You're now ready to set things up on the app side! In Cursor, while you can't set a new provider, you can change the endpoint for OpenAI requests in the model selection settings. First, navigate to "Settings" > "Cursor Settings", "Models" tab, and expand the "API Keys" collapsible. To set your `transformers serve` endpoint, follow this order:
1. Unselect ALL models in the list above (e.g. `gpt4`, ...);
2. Add and select the model you want to use (e.g. `Qwen/Qwen3-4B`)
3. Add some random text to OpenAI API Key. This field won't be used, but it can’t be empty;
@@ -225,3 +230,26 @@ Image URL: https://evalstate-flux1-schnell.hf.space/gradio_api/file=/tmp/gradio/
I have generated an image of a cat on the moon using the Flux 1 Schnell Image Generator. The image is 1024x1024 pixels and was created with 4 inference steps. Let me know if you would like to make any changes or need further assistance!
```
+
+### Usage example 4: speech to text transcription (feat. Open WebUI)
+
+This guide shows how to do audio transcription for chat purposes, using `transformers serve` and [Open WebUI](https://openwebui.com/). It assumes you have Open WebUI installed on your machine and ready to run. Please refer to the examples above to use the text functionalities of `transformers serve` with Open WebUI -- the instructions are the same.
+
+To start, let's launch the server. Some of Open WebUI's requests require [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS), which is disabled by default for security reasons, so you need to enable it:
+
+```shell
+transformers serve --enable-cors
+```
+
+Before you can speak into Open WebUI, you need to update its settings to use your server for speech to text (STT) tasks. Launch Open WebUI, and navigate to the audio tab inside the admin settings. If you're using Open WebUI with the default ports, [this link (default)](http://localhost:3000/admin/settings/audio) or [this link (python deployment)](http://localhost:8080/admin/settings/audio) will take you there. Make the following changes there:
+1. Change the type of "Speech-to-Text Engine" to "OpenAI";
+2. Update the address to your server's address -- `http://localhost:8000/v1` by default;
+3. Type your model of choice into the "STT Model" field, e.g. `openai/whisper-large-v3` ([available models](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&sort=trending)).
+
+If you've done everything correctly, the audio tab should look like this:
+
+[Screenshot: Open WebUI audio settings tab]
+
+You're now ready to speak! Open a new chat, utter a few words after hitting the microphone button, and you should see the corresponding text on the chat input after the model transcribes it.
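The transcription endpoint can also be exercised without Open WebUI. The following is a minimal sketch, not part of the change itself: it assumes a server started with `transformers serve` on the default `http://localhost:8000`, and it builds the `multipart/form-data` body by hand with the two form fields the endpoint reads (`file` and `model`). The helper names `make_silent_wav`, `build_transcription_request`, and `send_transcription_request` are hypothetical and exist only for illustration.

```python
import io
import urllib.request
import uuid
import wave


def make_silent_wav(seconds: float = 0.5, sample_rate: int = 16000) -> bytes:
    """Create a short mono 16-bit WAV clip of silence, to have something to upload."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit audio
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(b"\x00\x00" * int(seconds * sample_rate))
    return buffer.getvalue()


def build_transcription_request(
    audio: bytes,
    model: str,
    base_url: str = "http://localhost:8000/v1",  # assumed default server address
    filename: str = "audio.wav",
) -> urllib.request.Request:
    """Build a multipart/form-data POST with the `model` and `file` form fields."""
    boundary = uuid.uuid4().hex
    body = b"".join(
        [
            f"--{boundary}\r\n".encode(),
            b'Content-Disposition: form-data; name="model"\r\n\r\n',
            model.encode() + b"\r\n",
            f"--{boundary}\r\n".encode(),
            f'Content-Disposition: form-data; name="file"; filename="{filename}"\r\n'.encode(),
            b"Content-Type: audio/wav\r\n\r\n",
            audio + b"\r\n",
            f"--{boundary}--\r\n".encode(),
        ]
    )
    return urllib.request.Request(
        f"{base_url}/audio/transcriptions",
        data=body,
        headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
        method="POST",
    )


def send_transcription_request(request: urllib.request.Request) -> str:
    """Send the request and return the raw response body (requires a running server)."""
    with urllib.request.urlopen(request) as response:
        return response.read().decode()
```

With a server running, `send_transcription_request(build_transcription_request(make_silent_wav(), "openai/whisper-large-v3"))` should return a JSON string with a `text` field, since the handler serializes a `Transcription(text=...)` object. Replace the silent clip with the bytes of a real recording to get a meaningful transcription.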
diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index bf762b2214..3209f3c8ae 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -14,24 +14,28 @@
import copy
import functools
+import gc
+import io
import json
import re
+import threading
import time
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field
from threading import Thread
-from typing import Generator, Optional
+from typing import Generator, Optional, Union
from huggingface_hub import ModelInfo, model_info
from transformers.utils.import_utils import (
is_fastapi_available,
+ is_librosa_available,
is_openai_available,
is_pydantic_available,
is_uvicorn_available,
)
-from .. import LogitsProcessorList, PreTrainedTokenizerFast, TextIteratorStreamer
+from .. import LogitsProcessorList, PreTrainedTokenizerFast, ProcessorMixin, TextIteratorStreamer
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
from ..utils import is_torch_available, logging
from . import BaseTransformersCLICommand
@@ -42,12 +46,16 @@ if is_torch_available():
from transformers import (
AutoModelForCausalLM,
- AutoTokenizer,
+ AutoModelForSpeechSeq2Seq,
+ AutoProcessor,
BitsAndBytesConfig,
GenerationConfig,
PreTrainedModel,
)
+if is_librosa_available():
+ import librosa
+
serve_dependencies_available = (
is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available()
)
@@ -56,6 +64,8 @@ if serve_dependencies_available:
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
+ from openai.types.audio.transcription import Transcription
+ from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
Choice,
@@ -90,20 +100,28 @@ if serve_dependencies_available:
OpenAI's ResponseCreateParamsStreaming with an additional field for the generation config (as a json string).
"""
- generation_config: Optional[str]
+ generation_config: str
class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
"""
- OpenAI's CompletionCreateParamsStreaming with additional fields for the generation config (as a json string)
- and the request ID to re-use the previous KV cache.
+ OpenAI's CompletionCreateParamsStreaming with an additional field for the generation config (as a json string).
"""
- generation_config: Optional[str]
- request_id: Optional[str]
+ generation_config: str
- # Contrarily to OpenAI's output types, input types are `TypedDict`, which don't have validation
+ class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
+ """
+ OpenAI's TranscriptionCreateParamsBase with an additional field for the generation config (as a json string).
+ """
+
+ file: bytes # Overwritten -- pydantic isn't happy with `typing.IO[bytes]`, present in the original type
+ generation_config: str
+ stream: Optional[bool] = False
+
+ # Contrary to OpenAI's output types, input types are `TypedDict`, which don't have built-in validation.
response_validator = TypeAdapter(TransformersResponseCreateParamsStreaming)
completion_validator = TypeAdapter(TransformersCompletionCreateParamsStreaming)
+ transcription_validator = TypeAdapter(TransformersTranscriptionCreateParams)
# Define request fields that are not yet used in `transformers serve`. Receiving these fields will raise an
# HTTPException.
@@ -146,6 +164,14 @@ if serve_dependencies_available:
"user",
"web_search_options",
}
+ UNUSED_TRANSCRIPTION_FIELDS = {
+ "chunking_strategy",
+ "include",
+ "language",
+ "prompt",
+ "response_format",
+ "timestamp_granularities",
+ }
logger = logging.get_logger(__name__)
@@ -226,10 +252,6 @@ def create_generation_config_from_req(
if req.get("seed") is not None:
torch.manual_seed(req["seed"])
- # Sets server-specific defaults, if unset
- if generation_config.max_new_tokens is None:
- generation_config.max_new_tokens = 1024
-
return generation_config
@@ -247,6 +269,53 @@ class ToolState:
self.buffer = ""
+class TimedModel:
+ """
+ A class that holds a PreTrainedModel instance and its associated processor (tokenizer, audio processor, etc.).
+ Automatically deletes the instances after a specified timeout.
+ """
+
+ def __init__(
+ self,
+ model: "PreTrainedModel",
+ timeout_seconds: int,
+ processor: Optional[Union["ProcessorMixin", "PreTrainedTokenizerFast"]] = None,
+ ):
+ self.model = model
+ self._name_or_path = str(model.name_or_path)
+ self.processor = processor
+ self.timeout_seconds = timeout_seconds
+ self._timer = threading.Timer(self.timeout_seconds, self._delete_model)
+ self._timer.start()
+
+ def reset_timer(self):
+ """Reset the timer for the deletion of the instances."""
+ self._timer.cancel()
+ self._timer = threading.Timer(self.timeout_seconds, self._delete_model)
+ self._timer.start()
+
+ def _delete_model(self):
+ """Delete the wrapped model and processor and clean up resources."""
+ if hasattr(self, "model") and self.model is not None:
+ del self.model
+ del self.processor
+ self.model = None
+ self.processor = None
+ gc.collect()
+
+ # Clear CUDA cache if available
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ logger.info(
+ f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity"
+ )
+
+ def is_deleted(self):
+ """Check if the instances have been deleted."""
+ return not hasattr(self, "model") or self.model is None
+
+
@dataclass
class ServeArguments:
r"""
@@ -289,6 +358,10 @@ class ServeArguments:
# Serving settings
host: str = field(default="localhost", metadata={"help": "Interface the server will listen to."})
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})
+ model_timeout: int = field(
+ default=300,
+ metadata={"help": "Time in seconds of inactivity after which a model will be removed from memory."},
+ )
# Other settings
log_level: str = field(
@@ -357,16 +430,15 @@ class ServeCommand(BaseTransformersCLICommand):
cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
# Internal state:
- # 1. Tracks the most recently used model, to prevent reloading the model unnecessarily
- self.loaded_model: Optional[str] = None
+ # 1. Tracks the models in memory, to avoid reloading them unnecessarily
+ self.loaded_models: dict[str, TimedModel] = {}
self.running_continuous_batching_manager: Optional[ContinuousBatchingManager] = None
- self.model: PreTrainedModel
- self.tokenizer: PreTrainedTokenizerFast
# 2. preserves information about the last call and last KV cache, to determine whether we can reuse the KV
# cache and avoid re-running prefil
self.last_messages = None
self.last_kv_cache = None
+ self.last_text_model = None
def _validate_request(
self,
@@ -433,10 +505,19 @@ class ServeCommand(BaseTransformersCLICommand):
unused_fields=UNUSED_CHAT_COMPLETION_FIELDS,
)
+ def validate_transcription_request(self, request: dict):
+ self._validate_request(
+ request=request,
+ schema=TransformersTranscriptionCreateParams,
+ validator=transcription_validator,
+ unused_fields=UNUSED_TRANSCRIPTION_FIELDS,
+ )
+
def build_chat_completion_chunk(
self,
request_id: Optional[str] = "",
content: Optional[str] = None,
+ model: Optional[str] = None,
role: Optional[str] = None,
finish_reason: Optional[str] = None,
tool_calls: Optional[list["ChoiceDeltaToolCall"]] = None,
@@ -452,6 +533,8 @@ class ServeCommand(BaseTransformersCLICommand):
The request ID.
content (`str`, *optional*):
Content of the response from the model.
+ model (`str`, *optional*):
+ The model that generated the content.
role (`str`, *optional*):
The role of the next content, until a new role is defined.
finish_reason (`str`, *optional*):
@@ -465,7 +548,7 @@ class ServeCommand(BaseTransformersCLICommand):
chunk = ChatCompletionChunk(
id=request_id,
created=int(time.time()),
- model=self.loaded_model,
+ model=model,
choices=[
Choice(
delta=ChoiceDelta(
@@ -529,6 +612,26 @@ class ServeCommand(BaseTransformersCLICommand):
output = self.generate_response(request)
return StreamingResponse(output, media_type="text/event-stream")
+ from fastapi import Request
+
+ @app.post("/v1/audio/transcriptions")
+ async def audio_transcriptions(request: Request):
+ # Parses the multipart/form-data request into the request format used by other endpoints
+ async with request.form() as form:
+ parsed_request = TransformersTranscriptionCreateParams(
+ file=await form["file"].read(),
+ model=form["model"],
+ # TODO: add other fields
+ )
+ logger.debug(
+ f"Received file: {form['file'].filename}; MIME type: {form['file'].content_type}; "
+ f"size: {form['file'].size / 1024:.2f} KiB"
+ )
+ self.validate_transcription_request(request=parsed_request)
+
+ output = self.generate_transcription(parsed_request)
+ return StreamingResponse(output, media_type="text/event-stream")
+
@app.get("/v1/models")
def get_all_models():
return JSONResponse(
@@ -579,22 +682,22 @@ class ServeCommand(BaseTransformersCLICommand):
Returns:
`Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
"""
- if self.args.force_model is not None:
- req["model"] = self.args.force_model
- update_model = self.canonicalized_model_name(req["model"]) != self.loaded_model
- if update_model:
+ model_id_and_revision = self.process_model_name(req["model"])
+ must_discard_cache = model_id_and_revision != self.last_text_model
+ self.last_text_model = model_id_and_revision
+ if must_discard_cache:
# When switching models, terminate a continuous batching manager if it is running.
if self.running_continuous_batching_manager is not None:
self.running_continuous_batching_manager.stop(block=True, timeout=2)
self.running_continuous_batching_manager = None
- self.load_model_and_tokenizer(req["model"], self.args)
+ model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision)
generation_config = create_generation_config_from_req(
req,
- model_generation_config=self.model.generation_config,
- eos_token_id=self.tokenizer.eos_token_id,
- pad_token_id=self.tokenizer.pad_token_id,
+ model_generation_config=model.generation_config,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id,
use_cache=False,
num_blocks=1,
block_size=1024,
@@ -604,7 +707,7 @@ class ServeCommand(BaseTransformersCLICommand):
)
if self.running_continuous_batching_manager is None:
- self.running_continuous_batching_manager = self.model.init_continuous_batching(
+ self.running_continuous_batching_manager = model.init_continuous_batching(
generation_config=generation_config, streaming=True
)
@@ -614,9 +717,9 @@ class ServeCommand(BaseTransformersCLICommand):
self.running_continuous_batching_manager.start()
# TODO (Joao, Lysandre): this should also work with tool support
- inputs = self.tokenizer.apply_chat_template(
- req["messages"], return_tensors="pt", add_generation_prompt=True
- ).to(self.model.device)
+ inputs = tokenizer.apply_chat_template(req["messages"], return_tensors="pt", add_generation_prompt=True).to(
+ model.device
+ )
def stream_chat_completion(_inputs):
try:
@@ -628,7 +731,7 @@ class ServeCommand(BaseTransformersCLICommand):
# Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
# they come from the assistant.
- yield self.build_chat_completion_chunk(request_id, role="assistant")
+ yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)
for result in self.running_continuous_batching_manager:
if result.request_id != request_id:
@@ -641,10 +744,14 @@ class ServeCommand(BaseTransformersCLICommand):
finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
if result.status == RequestStatus.FINISHED:
- yield self.build_chat_completion_chunk(request_id, finish_reason=finish_reason)
+ yield self.build_chat_completion_chunk(
+ request_id, finish_reason=finish_reason, model=model_id_and_revision
+ )
break
else:
- yield self.build_chat_completion_chunk(request_id=request_id, content=result.next_token)
+ yield self.build_chat_completion_chunk(
+ request_id=request_id, content=result.next_token, model=model_id_and_revision
+ )
except Exception as e:
logger.error(str(e))
@@ -662,22 +769,20 @@ class ServeCommand(BaseTransformersCLICommand):
Returns:
`Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
"""
- if self.args.force_model is not None:
- req["model"] = self.args.force_model
-
- update_model = self.canonicalized_model_name(req["model"]) != self.loaded_model
- if update_model:
- self.load_model_and_tokenizer(req["model"], self.args)
-
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
# request whose last message is from the assistant.
if req["messages"][-1]["role"] == "assistant":
return
+ model_id_and_revision = self.process_model_name(req["model"])
+ must_discard_cache = model_id_and_revision != self.last_text_model
+ self.last_text_model = model_id_and_revision
+ model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision)
+
# ====== TOOL PREPROCESSING LOGIC ======
tool_model_family = None
for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
- if supported_model_families in self.model.config.architectures[0].lower():
+ if supported_model_families in model.config.architectures[0].lower():
tool_model_family = supported_model_families
break
# TODO: trigger 2 constrained generations after the tool call start token is emitted:
@@ -686,21 +791,19 @@ class ServeCommand(BaseTransformersCLICommand):
# ====== END OF TOOL PREPROCESSING LOGIC ======
if tool_model_family is not None:
- text = self.tokenizer.apply_chat_template(
+ text = tokenizer.apply_chat_template(
req["messages"], add_generation_prompt=True, tokenize=False, tools=req.get("tools")
)
else:
- text = self.tokenizer.apply_chat_template(req["messages"], add_generation_prompt=True, tokenize=False)
- inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
+ text = tokenizer.apply_chat_template(req["messages"], add_generation_prompt=True, tokenize=False)
+ inputs = tokenizer(text, return_tensors="pt").to(model.device)["input_ids"]
request_id = req.get("request_id", "req_0")
- generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
- generation_config = create_generation_config_from_req(
- req, model_generation_config=self.model.generation_config
- )
+ generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
+ generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
last_kv_cache = None
- if self.is_continuation(req) and not update_model:
+ if self.is_continuation(req) and not must_discard_cache:
last_kv_cache = self.last_kv_cache
generation_kwargs = {
@@ -715,7 +818,7 @@ class ServeCommand(BaseTransformersCLICommand):
def stream_chat_completion(streamer, _request_id):
# Thin wrapper to save the KV cache after generation
def generate_with_cache(**kwargs):
- generate_output = self.model.generate(**kwargs)
+ generate_output = model.generate(**kwargs)
self.last_kv_cache = generate_output.past_key_values
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
@@ -726,7 +829,7 @@ class ServeCommand(BaseTransformersCLICommand):
# Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
# they come from the assistant.
- yield self.build_chat_completion_chunk(request_id, role="assistant")
+ yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)
for result in streamer:
# ====== TOOL CALL LOGIC ======
@@ -740,10 +843,13 @@ class ServeCommand(BaseTransformersCLICommand):
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
tool_state.reset()
yield self.build_chat_completion_chunk(
- request_id=_request_id, role=None, finish_reason="tool_calls"
+ request_id=_request_id,
+ role=None,
+ finish_reason="tool_calls",
+ model=model_id_and_revision,
)
- continue
+ continue
# Inside a tool call
if tool_state.inside_tool_call:
tool_state.buffer += result
@@ -789,15 +895,17 @@ class ServeCommand(BaseTransformersCLICommand):
)
yield self.build_chat_completion_chunk(
- request_id=_request_id, role=None, tool_calls=[tool]
+ request_id=_request_id, role=None, tool_calls=[tool], model=model_id_and_revision
)
continue
# ====== END OF TOOL CALL LOGIC ======
# All non-tool related tokens are emitted as assistant messages. Empty text is skipped.
if result != "":
- yield self.build_chat_completion_chunk(_request_id, content=result)
- yield self.build_chat_completion_chunk(_request_id, finish_reason="stop")
+ yield self.build_chat_completion_chunk(
+ _request_id, content=result, model=model_id_and_revision
+ )
+ yield self.build_chat_completion_chunk(_request_id, finish_reason="stop", model=model_id_and_revision)
thread.join()
except Exception as e:
@@ -820,24 +928,20 @@ class ServeCommand(BaseTransformersCLICommand):
`Generator[str, None, None]`: A generator that yields the OpenAI Response events.
"""
# TODO -- Implement non-streaming mode
- if self.args.force_model is not None:
- req["model"] = self.args.force_model
+ model_id_and_revision = self.process_model_name(req["model"])
+ must_discard_cache = model_id_and_revision != self.last_text_model
+ self.last_text_model = model_id_and_revision
+ model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision)
- update_model = self.canonicalized_model_name(req["model"]) != self.loaded_model
- if update_model:
- self.load_model_and_tokenizer(req["model"], self.args)
-
- text = self.tokenizer.apply_chat_template(req["input"], add_generation_prompt=True, tokenize=False)
- inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
+ text = tokenizer.apply_chat_template(req["input"], add_generation_prompt=True, tokenize=False)
+ inputs = tokenizer(text, return_tensors="pt").to(model.device)["input_ids"]
request_id = req.get("previous_response_id", "req_0")
- generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
- generation_config = create_generation_config_from_req(
- req, model_generation_config=self.model.generation_config
- )
+ generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
+ generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
last_kv_cache = None
- if self.is_continuation(req) and not update_model:
+ if self.is_continuation(req) and not must_discard_cache:
last_kv_cache = self.last_kv_cache
generation_kwargs = {
@@ -850,7 +954,12 @@ class ServeCommand(BaseTransformersCLICommand):
}
def stream_response(streamer, _request_id):
- thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+ # Thin wrapper to save the KV cache after generation
+ def generate_with_cache(**kwargs):
+ generate_output = model.generate(**kwargs)
+ self.last_kv_cache = generate_output.past_key_values
+
+ thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
sequence_number = 0
output_index = 0
content_index = 0
@@ -868,7 +977,7 @@ class ServeCommand(BaseTransformersCLICommand):
id=f"resp_{request_id}",
created_at=created_at,
status="queued",
- model=self.loaded_model,
+ model=model_id_and_revision,
instructions=req.get("instructions"),
text={"format": {"type": "text"}},
object="response",
@@ -889,7 +998,7 @@ class ServeCommand(BaseTransformersCLICommand):
id=f"resp_{request_id}",
created_at=created_at,
status="in_progress",
- model=self.loaded_model,
+ model=model_id_and_revision,
instructions=req.get("instructions"),
text={"format": {"type": "text"}},
object="response",
@@ -994,7 +1103,7 @@ class ServeCommand(BaseTransformersCLICommand):
id=f"resp_{request_id}",
created_at=created_at,
status="completed",
- model=self.loaded_model,
+ model=model_id_and_revision,
instructions=req.get("instructions"),
text={"format": {"type": "text"}},
output=[response_output_item_done.item],
@@ -1026,7 +1135,7 @@ class ServeCommand(BaseTransformersCLICommand):
id=f"resp_{request_id}",
created_at=created_at,
status="failed",
- model=self.loaded_model,
+ model=model_id_and_revision,
instructions=req.get("instructions"),
text={"format": {"type": "text"}},
output=[],
@@ -1049,6 +1158,54 @@ class ServeCommand(BaseTransformersCLICommand):
return stream_response(generation_streamer, request_id)
+ def generate_transcription(self, req: dict) -> Generator[str, None, None]:
+ """
+ Generates an OpenAI Transcription using the audio file.
+
+ Args:
+ req (`dict`): The request containing the audio file and model information.
+
+ Returns:
+ `Generator[str, None, None]`: A generator that yields the transcription result.
+ """
+ # TODO: implement streaming transcription (currently, it's not streaming)
+ if not is_librosa_available():
+ raise ImportError(
+ "Missing librosa dependency for audio transcription. Please install with `pip install librosa`"
+ )
+ model_id_and_revision = self.process_model_name(req["model"])
+ audio_model, audio_processor = self.load_audio_model_and_processor(model_id_and_revision)
+
+ generation_streamer = TextIteratorStreamer(
+ audio_processor.tokenizer, skip_special_tokens=True, skip_prompt=True
+ )
+ generation_config = create_generation_config_from_req(
+ req, model_generation_config=audio_model.generation_config
+ )
+
+ # Read the binary audio file using librosa
+ model_sampling_rate = audio_processor.feature_extractor.sampling_rate
+ audio_bytes = io.BytesIO(req["file"])
+ audio_array, _ = librosa.load(audio_bytes, sr=model_sampling_rate, mono=True)
+ audio_inputs = audio_processor(audio_array, sampling_rate=model_sampling_rate, return_tensors="pt").to(
+ audio_model.device
+ )
+ audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype)
+
+ generation_kwargs = {
+ "streamer": generation_streamer,
+ "generation_config": generation_config,
+ "return_dict_in_generate": True,
+ }
+
+ def _generate_transcription():
+ generated_ids = audio_model.generate(**audio_inputs, **generation_kwargs)
+ transcription_text = audio_processor.batch_decode(generated_ids.sequences, skip_special_tokens=True)[0]
+ transcription = Transcription(text=transcription_text)
+ yield f"{transcription.model_dump_json(exclude_none=True)}"
+
+ return _generate_transcription()
+
def is_continuation(self, req: dict) -> bool:
"""
Determines whether the current request is a continuation of the last request. In other words, if it is the
@@ -1108,39 +1265,49 @@ class ServeCommand(BaseTransformersCLICommand):
return quantization_config
- def canonicalized_model_name(self, model_id: str) -> str:
+ def process_model_name(self, model_id: str) -> str:
"""
- Canonicalizes the model name to the format "model_id@revision". If the model_id DOESN'T contain an @, it
- defaults to "model_id@main".
+ Applies the `force_model` CLI argument and canonicalizes the model name to the format "model_id@revision".
+ If the model_id DOESN'T contain an @, it defaults to "model_id@main".
Args:
model_id (`str`): The model ID.
Returns:
- `str`: The canonicalized model name.
+ `str`: The canonicalized model name to be used
"""
+ if self.args.force_model is not None:
+ model_id = self.args.force_model
if "@" in model_id:
return model_id
return f"{model_id}@main"
- def load_model_and_tokenizer(self, model_id_and_revision: str, args: ServeArguments):
+ def _load_model_and_data_processor(
+ self, model_id_and_revision: str, model_cls: type[PreTrainedModel]
+ ) -> tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]:
"""
- Loads the model and tokenizer from the given model ID and revision into the ServeCommand instance.
+ Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI
+ arguments.
Args:
model_id_and_revision (`str`):
The model ID and revision to load.
- args (`ServeArguments`):
- The serve arguments. May contain quantization settings, device, etc.
+ model_cls (`type[PreTrainedModel]`):
+ The model class to load.
+
+ Returns:
+ `tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]`: The loaded model and
+ data processor (tokenizer, audio processor, etc.).
"""
- logger.warning(f"Loading {model_id_and_revision}")
+ args = self.args
+ logger.info(f"Loading {model_id_and_revision}")
if "@" in model_id_and_revision:
model_id, revision = model_id_and_revision.split("@", 1)
else:
model_id, revision = model_id_and_revision, "main"
- tokenizer = AutoTokenizer.from_pretrained(
+ data_processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
trust_remote_code=args.trust_remote_code,
@@ -1158,19 +1325,76 @@ class ServeCommand(BaseTransformersCLICommand):
"trust_remote_code": args.trust_remote_code,
}
- model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
-
- if model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 1024:
- model.generation_config.max_new_tokens = 1024
+ model = model_cls.from_pretrained(model_id, **model_kwargs)
if getattr(model, "hf_device_map", None) is None:
model = model.to(args.device)
- self.loaded_model = f"{model_id}@{revision}"
+ has_default_max_length = (
+ model.generation_config.max_new_tokens is None and model.generation_config.max_length == 20
+ )
+ has_short_max_new_tokens = (
+ model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 1024
+ )
+ if has_default_max_length or has_short_max_new_tokens:
+ model.generation_config.max_new_tokens = 1024
- logger.warning(f"Loaded model {self.loaded_model}")
- self.model = model
- self.tokenizer = tokenizer
+ logger.info(f"Loaded model {model_id_and_revision}")
+ return model, data_processor
+
+ def load_text_model_and_tokenizer(
+ self, model_id_and_revision: str
+ ) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
+ """
+ Loads the text model and tokenizer for the given model ID and revision, reusing the cached `TimedModel` entry (and resetting its timer) when one exists.
+
+ Args:
+ model_id_and_revision (`str`):
+ The model ID and revision to load.
+
+ Returns:
+ `tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and tokenizer.
+ """
+ if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
+ model, tokenizer = self._load_model_and_data_processor(model_id_and_revision, AutoModelForCausalLM)
+ self.loaded_models[model_id_and_revision] = TimedModel(
+ model,
+ timeout_seconds=self.args.model_timeout,
+ processor=tokenizer,
+ )
+ else:
+ self.loaded_models[model_id_and_revision].reset_timer()
+ model = self.loaded_models[model_id_and_revision].model
+ tokenizer = self.loaded_models[model_id_and_revision].processor
+
+ return model, tokenizer
+
+ def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, ProcessorMixin]:
+ """
+ Loads the audio model and processor for the given model ID and revision, reusing the cached `TimedModel` entry (and resetting its timer) when one exists.
+
+ Args:
+ model_id_and_revision (`str`):
+ The model ID and revision to load.
+
+ Returns:
+ `tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor.
+ """
+ if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
+ audio_model, audio_processor = self._load_model_and_data_processor(
+ model_id_and_revision, AutoModelForSpeechSeq2Seq
+ )
+ self.loaded_models[model_id_and_revision] = TimedModel(
+ audio_model,
+ timeout_seconds=self.args.model_timeout,
+ processor=audio_processor,
+ )
+ else:
+ self.loaded_models[model_id_and_revision].reset_timer()
+ audio_model = self.loaded_models[model_id_and_revision].model
+ audio_processor = self.loaded_models[model_id_and_revision].processor
+
+ return audio_model, audio_processor
if __name__ == "__main__":
diff --git a/tests/commands/test_serving.py b/tests/commands/test_serving.py
index 403f08ae7e..ed344ef7ed 100644
--- a/tests/commands/test_serving.py
+++ b/tests/commands/test_serving.py
@@ -63,14 +63,13 @@ class ServeCLITest(unittest.TestCase):
"""
dummy = ServeCommand.__new__(ServeCommand)
dummy.args = type("Args", (), {})()
- dummy.loaded_model = "dummy_model@main"
# The keys for these fields must be present in every chunk
MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"]
# Case 1: most fields are provided
chunk = ServeCommand.build_chat_completion_chunk(
- dummy, request_id="req0", content="hello", finish_reason="stop", role="user"
+ dummy, request_id="req0", content="hello", finish_reason="stop", role="user", model="dummy_model@main"
)
for field in MANDATORY_FIELDS:
self.assertIn(field, chunk)
@@ -79,13 +78,13 @@ class ServeCLITest(unittest.TestCase):
)
# Case 2: only the role is provided -- other fields in 'choices' are omitted
- chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user")
+ chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user", model="dummy_model@main")
for field in MANDATORY_FIELDS:
self.assertIn(field, chunk)
self.assertIn('"choices":[{"delta":{"role":"user"},"index":0}]', chunk)
# Case 3: only the content is provided -- other fields in 'choices' are omitted
- chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello")
+ chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello", model="dummy_model@main")
for field in MANDATORY_FIELDS:
self.assertIn(field, chunk)
self.assertIn('"choices":[{"delta":{"content":"hello"},"index":0}]', chunk)
@@ -96,7 +95,7 @@ class ServeCLITest(unittest.TestCase):
function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'),
type="function",
)
- chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call])
+ chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call], model="dummy_model@main")
for field in MANDATORY_FIELDS:
self.assertIn(field, chunk)
expected_choices_content = (