[serve] Add speech to text (/v1/audio/transcriptions) (#39434)
* Scaffolding * Explicit content * Naïve Responses API streaming implementation * Cleanup * Scaffolding * Explicit content * Naïve Responses API streaming implementation * Cleanup * use openai * validate request, including detecting unused fields * dict indexing * dict var access * tmp commit (tests failing) * add slow * use oai output type in completions * (little rebase errors) * working spec? * guard type hint * type hints. fix state (CB can now load different models) * type hints; fn names; error type * add docstrings * responses + kv cache * metadata support; fix kv cache; error event * add output_index and content_index * docstrings * add test_build_response_event * docs/comments * gate test requirements; terminate cb manager on model switch * nasty type hints * more type hints * disable validation by default; enable force models * todo * experiment: base model from typed dict * audio working * fix bad rebase * load audio with librosa * implement timed models * almost working * make fixup * fix tests * transcription request type * tokenizer -> processor * add example in docs --------- Co-authored-by: Lysandre <hi@lysand.re>
This commit is contained in:
@@ -70,8 +70,13 @@ vllm serve Qwen/Qwen2.5-1.5B-Instruct \
|
|||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
> This section is experimental and subject to change in future versions
|
> This section is experimental and subject to change in future versions
|
||||||
|
|
||||||
<!-- TODO: LLMs -> models, after we add audio/image input/output support -->
|
You can serve models of diverse modalities supported by `transformers` with the `transformers serve` CLI. It spawns a local server that offers compatibility with the OpenAI SDK, which is the _de facto_ standard for LLM conversations and other related tasks. This way, you can use the server from many third party applications, or test it using the `transformers chat` CLI ([docs](conversations.md#chat-cli)).
|
||||||
You can serve LLMs supported by `transformers` with the `transformers serve` CLI. It spawns a local server that offers a Chat Completion API or a Response API compatible with the OpenAI SDK, which are the _de facto_ standard for LLM conversations. This way, you can use the server from many third party applications, or test it using the `transformers chat` CLI ([docs](conversations.md#chat-cli)).
|
|
||||||
|
The server supports the following REST APIs:
|
||||||
|
- `/v1/chat/completions`
|
||||||
|
- `/v1/responses`
|
||||||
|
- `/v1/audio/transcriptions`
|
||||||
|
- `/v1/models`
|
||||||
|
|
||||||
To launch a server, simply use the `transformers serve` CLI command:
|
To launch a server, simply use the `transformers serve` CLI command:
|
||||||
|
|
||||||
@@ -109,7 +114,7 @@ The server is also an MCP client, so it can interact with MCP tools in agentic u
|
|||||||
<!-- TODO: example with a minimal python example, and explain that it is possible to pass a full generation config in the request -->
|
<!-- TODO: example with a minimal python example, and explain that it is possible to pass a full generation config in the request -->
|
||||||
|
|
||||||
|
|
||||||
### Usage example 1: apps with local requests (feat. Jan)
|
### Usage example 1: chat with local requests (feat. Jan)
|
||||||
|
|
||||||
This example shows how to use `transformers serve` as a local LLM provider for the [Jan](https://jan.ai/) app. Jan is a ChatGPT-alternative graphical interface, fully running on your machine. The requests to `transformers serve` come directly from the local app -- while this section focuses on Jan, you can extrapolate some instructions to other apps that make local requests.
|
This example shows how to use `transformers serve` as a local LLM provider for the [Jan](https://jan.ai/) app. Jan is a ChatGPT-alternative graphical interface, fully running on your machine. The requests to `transformers serve` come directly from the local app -- while this section focuses on Jan, you can extrapolate some instructions to other apps that make local requests.
|
||||||
|
|
||||||
@@ -139,17 +144,17 @@ ssh -N -f -L 8000:localhost:8000 your_server_account@your_server_IP -p port_to_s
|
|||||||
Port forwarding is not Jan-specific: you can use it to connect `transformers serve` running in a different machine with an app of your choice.
|
Port forwarding is not Jan-specific: you can use it to connect `transformers serve` running in a different machine with an app of your choice.
|
||||||
|
|
||||||
|
|
||||||
### Usage example 2: apps with external requests (feat. Cursor)
|
### Usage example 2: chat with external requests (feat. Cursor)
|
||||||
|
|
||||||
This example shows how to use `transformers serve` as a local LLM provider for [Cursor](https://cursor.com/), the popular IDE. Unlike in the previous example, requests to `transformers serve` will come from an external IP (Cursor's server IPs), which requires some additional setup. Furthermore, some of Cursor's requests require [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS), which is disabled by default for security reasons.
|
This example shows how to use `transformers serve` as a local LLM provider for [Cursor](https://cursor.com/), the popular IDE. Unlike in the previous example, requests to `transformers serve` will come from an external IP (Cursor's server IPs), which requires some additional setup. Furthermore, some of Cursor's requests require [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS), which is disabled by default for security reasons.
|
||||||
|
|
||||||
To launch our server with CORS enabled, run
|
To launch a server with CORS enabled, run
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
transformers serve --enable-cors
|
transformers serve --enable-cors
|
||||||
```
|
```
|
||||||
|
|
||||||
We'll also need to expose our server to external IPs. A potential solution is to use [`ngrok`](https://ngrok.com/), which has a permissive free tier. After setting up your `ngrok` account and authenticating on your server machine, you run
|
You'll also need to expose your server to external IPs. A potential solution is to use [`ngrok`](https://ngrok.com/), which has a permissive free tier. After setting up your `ngrok` account and authenticating on your server machine, you run
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
ngrok http [port]
|
ngrok http [port]
|
||||||
@@ -161,7 +166,7 @@ where `port` is the port used by `transformers serve` (`8000` by default). On th
|
|||||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers_serve_ngrok.png"/>
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers_serve_ngrok.png"/>
|
||||||
</h3>
|
</h3>
|
||||||
|
|
||||||
We're now ready to set things up on the app side! In Cursor, while we can't set a new provider, we can change the endpoint for OpenAI requests in the model selection settings. First, navigate to "Settings" > "Cursor Settings", "Models" tab, and expand the "API Keys" collapsible. To set our `transformers serve` endpoint, follow this order:
|
You're now ready to set things up on the app side! In Cursor, while you can't set a new provider, you can change the endpoint for OpenAI requests in the model selection settings. First, navigate to "Settings" > "Cursor Settings", "Models" tab, and expand the "API Keys" collapsible. To set your `transformers serve` endpoint, follow this order:
|
||||||
1. Unselect ALL models in the list above (e.g. `gpt4`, ...);
|
1. Unselect ALL models in the list above (e.g. `gpt4`, ...);
|
||||||
2. Add and select the model you want to use (e.g. `Qwen/Qwen3-4B`)
|
2. Add and select the model you want to use (e.g. `Qwen/Qwen3-4B`)
|
||||||
3. Add some random text to OpenAI API Key. This field won't be used, but it can’t be empty;
|
3. Add some random text to OpenAI API Key. This field won't be used, but it can’t be empty;
|
||||||
@@ -225,3 +230,26 @@ Image URL: https://evalstate-flux1-schnell.hf.space/gradio_api/file=/tmp/gradio/
|
|||||||
|
|
||||||
I have generated an image of a cat on the moon using the Flux 1 Schnell Image Generator. The image is 1024x1024 pixels and was created with 4 inference steps. Let me know if you would like to make any changes or need further assistance!
|
I have generated an image of a cat on the moon using the Flux 1 Schnell Image Generator. The image is 1024x1024 pixels and was created with 4 inference steps. Let me know if you would like to make any changes or need further assistance!
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Usage example 4: speech to text transcription (feat. Open WebUI)
|
||||||
|
|
||||||
|
This guide shows how to do audio transcription for chat purposes, using `transformers serve` and [Open WebUI](https://openwebui.com/). This guide assumes you have Open WebUI installed on your machine and ready to run. Please refer to the examples above to use the text functionalities of `transformer serve` with Open WebUI -- the instructions are the same.
|
||||||
|
|
||||||
|
To start, let's launch the server. Some of Open WebUI's requests require [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS), which is disabled by default for security reasons, so you need to enable it:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
transformers serve --enable-cors
|
||||||
|
```
|
||||||
|
|
||||||
|
Before you can speak into Open WebUI, you need to update its settings to use your server for speech to text (STT) tasks. Launch Open WebUI, and navigate to the audio tab inside the admin settings. If you're using Open WebUI with the default ports, [this link (default)](http://localhost:3000/admin/settings/audio) or [this link (python deployment)](http://localhost:8080/admin/settings/audio) will take you there. Do the following changes there:
|
||||||
|
1. Change the type of "Speech-to-Text Engine" to "OpenAI";
|
||||||
|
2. Update the address to your server's address -- `http://localhost:8000/v1` by default;
|
||||||
|
3. Type your model of choice into the "STT Model" field, e.g. `openai/whisper-large-v3` ([available models](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&sort=trending)).
|
||||||
|
|
||||||
|
If you've done everything correctly, the audio tab should look like this
|
||||||
|
|
||||||
|
<h3 align="center">
|
||||||
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/transformers_openwebui_stt_settings.png"/>
|
||||||
|
</h3>
|
||||||
|
|
||||||
|
You're now ready to speak! Open a new chat, utter a few words after hitting the microphone button, and you should see the corresponding text on the chat input after the model transcribes it.
|
||||||
|
|||||||
@@ -14,24 +14,28 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
|
import gc
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Generator, Optional
|
from typing import Generator, Optional, Union
|
||||||
|
|
||||||
from huggingface_hub import ModelInfo, model_info
|
from huggingface_hub import ModelInfo, model_info
|
||||||
|
|
||||||
from transformers.utils.import_utils import (
|
from transformers.utils.import_utils import (
|
||||||
is_fastapi_available,
|
is_fastapi_available,
|
||||||
|
is_librosa_available,
|
||||||
is_openai_available,
|
is_openai_available,
|
||||||
is_pydantic_available,
|
is_pydantic_available,
|
||||||
is_uvicorn_available,
|
is_uvicorn_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .. import LogitsProcessorList, PreTrainedTokenizerFast, TextIteratorStreamer
|
from .. import LogitsProcessorList, PreTrainedTokenizerFast, ProcessorMixin, TextIteratorStreamer
|
||||||
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
|
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
|
||||||
from ..utils import is_torch_available, logging
|
from ..utils import is_torch_available, logging
|
||||||
from . import BaseTransformersCLICommand
|
from . import BaseTransformersCLICommand
|
||||||
@@ -42,12 +46,16 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoModelForSpeechSeq2Seq,
|
||||||
|
AutoProcessor,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
GenerationConfig,
|
GenerationConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_librosa_available():
|
||||||
|
import librosa
|
||||||
|
|
||||||
serve_dependencies_available = (
|
serve_dependencies_available = (
|
||||||
is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available()
|
is_pydantic_available() and is_fastapi_available() and is_uvicorn_available() and is_openai_available()
|
||||||
)
|
)
|
||||||
@@ -56,6 +64,8 @@ if serve_dependencies_available:
|
|||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
from openai.types.audio.transcription import Transcription
|
||||||
|
from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase
|
||||||
from openai.types.chat.chat_completion_chunk import (
|
from openai.types.chat.chat_completion_chunk import (
|
||||||
ChatCompletionChunk,
|
ChatCompletionChunk,
|
||||||
Choice,
|
Choice,
|
||||||
@@ -90,20 +100,28 @@ if serve_dependencies_available:
|
|||||||
OpenAI's ResponseCreateParamsStreaming with an additional field for the generation config (as a json string).
|
OpenAI's ResponseCreateParamsStreaming with an additional field for the generation config (as a json string).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
generation_config: Optional[str]
|
generation_config: str
|
||||||
|
|
||||||
class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
|
class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming, total=False):
|
||||||
"""
|
"""
|
||||||
OpenAI's CompletionCreateParamsStreaming with additional fields for the generation config (as a json string)
|
OpenAI's CompletionCreateParamsStreaming with an additional field for the generation config (as a json string).
|
||||||
and the request ID to re-use the previous KV cache.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
generation_config: Optional[str]
|
generation_config: str
|
||||||
request_id: Optional[str]
|
|
||||||
|
|
||||||
# Contrarily to OpenAI's output types, input types are `TypedDict`, which don't have validation
|
class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
|
||||||
|
"""
|
||||||
|
OpenAI's TranscriptionCreateParamsBase with an additional field for the generation config (as a json string).
|
||||||
|
"""
|
||||||
|
|
||||||
|
file: bytes # Overwritten -- pydantic isn't happy with `typing.IO[bytes]`, present in the original type
|
||||||
|
generation_config: str
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
|
||||||
|
# Contrarily to OpenAI's output types, input types are `TypedDict`, which don't have built-in validation.
|
||||||
response_validator = TypeAdapter(TransformersResponseCreateParamsStreaming)
|
response_validator = TypeAdapter(TransformersResponseCreateParamsStreaming)
|
||||||
completion_validator = TypeAdapter(TransformersCompletionCreateParamsStreaming)
|
completion_validator = TypeAdapter(TransformersCompletionCreateParamsStreaming)
|
||||||
|
transcription_validator = TypeAdapter(TransformersTranscriptionCreateParams)
|
||||||
|
|
||||||
# Define request fields that are not yet used in `transformers serve`. Receiving these fields will raise an
|
# Define request fields that are not yet used in `transformers serve`. Receiving these fields will raise an
|
||||||
# HTTPException.
|
# HTTPException.
|
||||||
@@ -146,6 +164,14 @@ if serve_dependencies_available:
|
|||||||
"user",
|
"user",
|
||||||
"web_search_options",
|
"web_search_options",
|
||||||
}
|
}
|
||||||
|
UNUSED_TRANSCRIPTION_FIELDS = {
|
||||||
|
"chunking_strategy",
|
||||||
|
"include",
|
||||||
|
"language",
|
||||||
|
"prompt",
|
||||||
|
"response_format",
|
||||||
|
"timestamp_granularities",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -226,10 +252,6 @@ def create_generation_config_from_req(
|
|||||||
if req.get("seed") is not None:
|
if req.get("seed") is not None:
|
||||||
torch.manual_seed(req["seed"])
|
torch.manual_seed(req["seed"])
|
||||||
|
|
||||||
# Sets server-specific defaults, if unset
|
|
||||||
if generation_config.max_new_tokens is None:
|
|
||||||
generation_config.max_new_tokens = 1024
|
|
||||||
|
|
||||||
return generation_config
|
return generation_config
|
||||||
|
|
||||||
|
|
||||||
@@ -247,6 +269,53 @@ class ToolState:
|
|||||||
self.buffer = ""
|
self.buffer = ""
|
||||||
|
|
||||||
|
|
||||||
|
class TimedModel:
|
||||||
|
"""
|
||||||
|
A class that holds a PreTrainedModel instance and its associated processor (tokenizer, audio processor, etc.).
|
||||||
|
Automatically deletes the instances after a specified timeout.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
timeout_seconds: int,
|
||||||
|
processor: Optional[Union["ProcessorMixin", "PreTrainedTokenizerFast"]] = None,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self._name_or_path = str(model.name_or_path)
|
||||||
|
self.processor = processor
|
||||||
|
self.timeout_seconds = timeout_seconds
|
||||||
|
self._timer = threading.Timer(self.timeout_seconds, self._delete_model)
|
||||||
|
self._timer.start()
|
||||||
|
|
||||||
|
def reset_timer(self):
|
||||||
|
"""Reset the timer for the deletion of the instances."""
|
||||||
|
self._timer.cancel()
|
||||||
|
self._timer = threading.Timer(self.timeout_seconds, self._delete_model)
|
||||||
|
self._timer.start()
|
||||||
|
|
||||||
|
def _delete_model(self):
|
||||||
|
"""Delete the wrapped model and processor and clean up resources."""
|
||||||
|
if hasattr(self, "model") and self.model is not None:
|
||||||
|
del self.model
|
||||||
|
del self.processor
|
||||||
|
self.model = None
|
||||||
|
self.processor = None
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Clear CUDA cache if available
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity"
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_deleted(self):
|
||||||
|
"""Check if the instances have been deleted."""
|
||||||
|
return not hasattr(self, "model") or self.model is None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ServeArguments:
|
class ServeArguments:
|
||||||
r"""
|
r"""
|
||||||
@@ -289,6 +358,10 @@ class ServeArguments:
|
|||||||
# Serving settings
|
# Serving settings
|
||||||
host: str = field(default="localhost", metadata={"help": "Interface the server will listen to."})
|
host: str = field(default="localhost", metadata={"help": "Interface the server will listen to."})
|
||||||
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})
|
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})
|
||||||
|
model_timeout: int = field(
|
||||||
|
default=300,
|
||||||
|
metadata={"help": "Time in seconds after which a model will be removed from memory."},
|
||||||
|
)
|
||||||
|
|
||||||
# Other settings
|
# Other settings
|
||||||
log_level: str = field(
|
log_level: str = field(
|
||||||
@@ -357,16 +430,15 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
|
cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
|
||||||
|
|
||||||
# Internal state:
|
# Internal state:
|
||||||
# 1. Tracks the most recently used model, to prevent reloading the model unnecessarily
|
# 1. Tracks models in memory, to prevent reloading the model unnecessarily
|
||||||
self.loaded_model: Optional[str] = None
|
self.loaded_models: dict[str, TimedModel] = {}
|
||||||
self.running_continuous_batching_manager: Optional[ContinuousBatchingManager] = None
|
self.running_continuous_batching_manager: Optional[ContinuousBatchingManager] = None
|
||||||
self.model: PreTrainedModel
|
|
||||||
self.tokenizer: PreTrainedTokenizerFast
|
|
||||||
|
|
||||||
# 2. preserves information about the last call and last KV cache, to determine whether we can reuse the KV
|
# 2. preserves information about the last call and last KV cache, to determine whether we can reuse the KV
|
||||||
# cache and avoid re-running prefil
|
# cache and avoid re-running prefil
|
||||||
self.last_messages = None
|
self.last_messages = None
|
||||||
self.last_kv_cache = None
|
self.last_kv_cache = None
|
||||||
|
self.last_text_model = None
|
||||||
|
|
||||||
def _validate_request(
|
def _validate_request(
|
||||||
self,
|
self,
|
||||||
@@ -433,10 +505,19 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
unused_fields=UNUSED_CHAT_COMPLETION_FIELDS,
|
unused_fields=UNUSED_CHAT_COMPLETION_FIELDS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def validate_transcription_request(self, request: dict):
|
||||||
|
self._validate_request(
|
||||||
|
request=request,
|
||||||
|
schema=TransformersTranscriptionCreateParams,
|
||||||
|
validator=transcription_validator,
|
||||||
|
unused_fields=UNUSED_TRANSCRIPTION_FIELDS,
|
||||||
|
)
|
||||||
|
|
||||||
def build_chat_completion_chunk(
|
def build_chat_completion_chunk(
|
||||||
self,
|
self,
|
||||||
request_id: Optional[str] = "",
|
request_id: Optional[str] = "",
|
||||||
content: Optional[str] = None,
|
content: Optional[str] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
role: Optional[str] = None,
|
role: Optional[str] = None,
|
||||||
finish_reason: Optional[str] = None,
|
finish_reason: Optional[str] = None,
|
||||||
tool_calls: Optional[list["ChoiceDeltaToolCall"]] = None,
|
tool_calls: Optional[list["ChoiceDeltaToolCall"]] = None,
|
||||||
@@ -452,6 +533,8 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
The request ID.
|
The request ID.
|
||||||
content (`str`, *optional*):
|
content (`str`, *optional*):
|
||||||
Content of the response from the model.
|
Content of the response from the model.
|
||||||
|
model (`str`, *optional*):
|
||||||
|
The model that generated the content.
|
||||||
role (`str`, *optional*):
|
role (`str`, *optional*):
|
||||||
The role of the next content, until a new role is defined.
|
The role of the next content, until a new role is defined.
|
||||||
finish_reason (`str`, *optional*):
|
finish_reason (`str`, *optional*):
|
||||||
@@ -465,7 +548,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
chunk = ChatCompletionChunk(
|
chunk = ChatCompletionChunk(
|
||||||
id=request_id,
|
id=request_id,
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
model=self.loaded_model,
|
model=model,
|
||||||
choices=[
|
choices=[
|
||||||
Choice(
|
Choice(
|
||||||
delta=ChoiceDelta(
|
delta=ChoiceDelta(
|
||||||
@@ -529,6 +612,26 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
output = self.generate_response(request)
|
output = self.generate_response(request)
|
||||||
return StreamingResponse(output, media_type="text/event-stream")
|
return StreamingResponse(output, media_type="text/event-stream")
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
@app.post("/v1/audio/transcriptions")
|
||||||
|
async def audio_transcriptions(request: Request):
|
||||||
|
# Parses the multipart/form-data request into the request format used by other endpoints
|
||||||
|
async with request.form() as form:
|
||||||
|
parsed_request = TransformersTranscriptionCreateParams(
|
||||||
|
file=await form["file"].read(),
|
||||||
|
model=form["model"],
|
||||||
|
# TODO: add other fields
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Received file: {form['file'].filename}; MIME type: {form['file'].content_type}; "
|
||||||
|
f"size: {form['file'].size / 1024:.2f} KiB"
|
||||||
|
)
|
||||||
|
self.validate_transcription_request(request=parsed_request)
|
||||||
|
|
||||||
|
output = self.generate_transcription(parsed_request)
|
||||||
|
return StreamingResponse(output, media_type="text/event-stream")
|
||||||
|
|
||||||
@app.get("/v1/models")
|
@app.get("/v1/models")
|
||||||
def get_all_models():
|
def get_all_models():
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@@ -579,22 +682,22 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
Returns:
|
Returns:
|
||||||
`Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
|
`Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
|
||||||
"""
|
"""
|
||||||
if self.args.force_model is not None:
|
|
||||||
req["model"] = self.args.force_model
|
|
||||||
|
|
||||||
update_model = self.canonicalized_model_name(req["model"]) != self.loaded_model
|
model_id_and_revision = self.process_model_name(req["model"])
|
||||||
if update_model:
|
must_discard_cache = model_id_and_revision != self.last_text_model
|
||||||
|
self.last_text_model = model_id_and_revision
|
||||||
|
if must_discard_cache:
|
||||||
# When switching models, terminate a continuous batching manager if it is running.
|
# When switching models, terminate a continuous batching manager if it is running.
|
||||||
if self.running_continuous_batching_manager is not None:
|
if self.running_continuous_batching_manager is not None:
|
||||||
self.running_continuous_batching_manager.stop(block=True, timeout=2)
|
self.running_continuous_batching_manager.stop(block=True, timeout=2)
|
||||||
self.running_continuous_batching_manager = None
|
self.running_continuous_batching_manager = None
|
||||||
self.load_model_and_tokenizer(req["model"], self.args)
|
model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision)
|
||||||
|
|
||||||
generation_config = create_generation_config_from_req(
|
generation_config = create_generation_config_from_req(
|
||||||
req,
|
req,
|
||||||
model_generation_config=self.model.generation_config,
|
model_generation_config=model.generation_config,
|
||||||
eos_token_id=self.tokenizer.eos_token_id,
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
pad_token_id=self.tokenizer.pad_token_id,
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
num_blocks=1,
|
num_blocks=1,
|
||||||
block_size=1024,
|
block_size=1024,
|
||||||
@@ -604,7 +707,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.running_continuous_batching_manager is None:
|
if self.running_continuous_batching_manager is None:
|
||||||
self.running_continuous_batching_manager = self.model.init_continuous_batching(
|
self.running_continuous_batching_manager = model.init_continuous_batching(
|
||||||
generation_config=generation_config, streaming=True
|
generation_config=generation_config, streaming=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -614,9 +717,9 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
self.running_continuous_batching_manager.start()
|
self.running_continuous_batching_manager.start()
|
||||||
|
|
||||||
# TODO (Joao, Lysandre): this should also work with tool support
|
# TODO (Joao, Lysandre): this should also work with tool support
|
||||||
inputs = self.tokenizer.apply_chat_template(
|
inputs = tokenizer.apply_chat_template(req["messages"], return_tensors="pt", add_generation_prompt=True).to(
|
||||||
req["messages"], return_tensors="pt", add_generation_prompt=True
|
model.device
|
||||||
).to(self.model.device)
|
)
|
||||||
|
|
||||||
def stream_chat_completion(_inputs):
|
def stream_chat_completion(_inputs):
|
||||||
try:
|
try:
|
||||||
@@ -628,7 +731,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
# Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
|
# Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
|
||||||
# they come from the assistant.
|
# they come from the assistant.
|
||||||
yield self.build_chat_completion_chunk(request_id, role="assistant")
|
yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)
|
||||||
|
|
||||||
for result in self.running_continuous_batching_manager:
|
for result in self.running_continuous_batching_manager:
|
||||||
if result.request_id != request_id:
|
if result.request_id != request_id:
|
||||||
@@ -641,10 +744,14 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
|
finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
|
||||||
if result.status == RequestStatus.FINISHED:
|
if result.status == RequestStatus.FINISHED:
|
||||||
yield self.build_chat_completion_chunk(request_id, finish_reason=finish_reason)
|
yield self.build_chat_completion_chunk(
|
||||||
|
request_id, finish_reason=finish_reason, model=model_id_and_revision
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
yield self.build_chat_completion_chunk(request_id=request_id, content=result.next_token)
|
yield self.build_chat_completion_chunk(
|
||||||
|
request_id=request_id, content=result.next_token, model=model_id_and_revision
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
@@ -662,22 +769,20 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
Returns:
|
Returns:
|
||||||
`Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
|
`Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
|
||||||
"""
|
"""
|
||||||
if self.args.force_model is not None:
|
|
||||||
req["model"] = self.args.force_model
|
|
||||||
|
|
||||||
update_model = self.canonicalized_model_name(req["model"]) != self.loaded_model
|
|
||||||
if update_model:
|
|
||||||
self.load_model_and_tokenizer(req["model"], self.args)
|
|
||||||
|
|
||||||
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
|
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
|
||||||
# request whose last message is from the assistant.
|
# request whose last message is from the assistant.
|
||||||
if req["messages"][-1]["role"] == "assistant":
|
if req["messages"][-1]["role"] == "assistant":
|
||||||
return
|
return
|
||||||
|
|
||||||
|
model_id_and_revision = self.process_model_name(req["model"])
|
||||||
|
must_discard_cache = model_id_and_revision != self.last_text_model
|
||||||
|
self.last_text_model = model_id_and_revision
|
||||||
|
model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision)
|
||||||
|
|
||||||
# ====== TOOL PREPROCESSING LOGIC ======
|
# ====== TOOL PREPROCESSING LOGIC ======
|
||||||
tool_model_family = None
|
tool_model_family = None
|
||||||
for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
|
for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
|
||||||
if supported_model_families in self.model.config.architectures[0].lower():
|
if supported_model_families in model.config.architectures[0].lower():
|
||||||
tool_model_family = supported_model_families
|
tool_model_family = supported_model_families
|
||||||
break
|
break
|
||||||
# TODO: trigger 2 constrained generations after the tool call start token is emitted:
|
# TODO: trigger 2 constrained generations after the tool call start token is emitted:
|
||||||
@@ -686,21 +791,19 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
# ====== END OF TOOL PREPROCESSING LOGIC ======
|
# ====== END OF TOOL PREPROCESSING LOGIC ======
|
||||||
|
|
||||||
if tool_model_family is not None:
|
if tool_model_family is not None:
|
||||||
text = self.tokenizer.apply_chat_template(
|
text = tokenizer.apply_chat_template(
|
||||||
req["messages"], add_generation_prompt=True, tokenize=False, tools=req.get("tools")
|
req["messages"], add_generation_prompt=True, tokenize=False, tools=req.get("tools")
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
text = self.tokenizer.apply_chat_template(req["messages"], add_generation_prompt=True, tokenize=False)
|
text = tokenizer.apply_chat_template(req["messages"], add_generation_prompt=True, tokenize=False)
|
||||||
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
|
inputs = tokenizer(text, return_tensors="pt").to(model.device)["input_ids"]
|
||||||
request_id = req.get("request_id", "req_0")
|
request_id = req.get("request_id", "req_0")
|
||||||
|
|
||||||
generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
|
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
||||||
generation_config = create_generation_config_from_req(
|
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
|
||||||
req, model_generation_config=self.model.generation_config
|
|
||||||
)
|
|
||||||
|
|
||||||
last_kv_cache = None
|
last_kv_cache = None
|
||||||
if self.is_continuation(req) and not update_model:
|
if self.is_continuation(req) and not must_discard_cache:
|
||||||
last_kv_cache = self.last_kv_cache
|
last_kv_cache = self.last_kv_cache
|
||||||
|
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
@@ -715,7 +818,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
def stream_chat_completion(streamer, _request_id):
|
def stream_chat_completion(streamer, _request_id):
|
||||||
# Thin wrapper to save the KV cache after generation
|
# Thin wrapper to save the KV cache after generation
|
||||||
def generate_with_cache(**kwargs):
|
def generate_with_cache(**kwargs):
|
||||||
generate_output = self.model.generate(**kwargs)
|
generate_output = model.generate(**kwargs)
|
||||||
self.last_kv_cache = generate_output.past_key_values
|
self.last_kv_cache = generate_output.past_key_values
|
||||||
|
|
||||||
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
|
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
|
||||||
@@ -726,7 +829,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
# Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
|
# Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
|
||||||
# they come from the assistant.
|
# they come from the assistant.
|
||||||
yield self.build_chat_completion_chunk(request_id, role="assistant")
|
yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)
|
||||||
|
|
||||||
for result in streamer:
|
for result in streamer:
|
||||||
# ====== TOOL CALL LOGIC ======
|
# ====== TOOL CALL LOGIC ======
|
||||||
@@ -740,10 +843,13 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
|
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
|
||||||
tool_state.reset()
|
tool_state.reset()
|
||||||
yield self.build_chat_completion_chunk(
|
yield self.build_chat_completion_chunk(
|
||||||
request_id=_request_id, role=None, finish_reason="tool_calls"
|
request_id=_request_id,
|
||||||
|
role=None,
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
model=model_id_and_revision,
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
|
continue
|
||||||
# Inside a tool call
|
# Inside a tool call
|
||||||
if tool_state.inside_tool_call:
|
if tool_state.inside_tool_call:
|
||||||
tool_state.buffer += result
|
tool_state.buffer += result
|
||||||
@@ -789,15 +895,17 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
)
|
)
|
||||||
|
|
||||||
yield self.build_chat_completion_chunk(
|
yield self.build_chat_completion_chunk(
|
||||||
request_id=_request_id, role=None, tool_calls=[tool]
|
request_id=_request_id, role=None, tool_calls=[tool], model=model_id_and_revision
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
# ====== END OF TOOL CALL LOGIC ======
|
# ====== END OF TOOL CALL LOGIC ======
|
||||||
|
|
||||||
# All non-tool related tokens are emitted as assistant messages. Empty text is skipped.
|
# All non-tool related tokens are emitted as assistant messages. Empty text is skipped.
|
||||||
if result != "":
|
if result != "":
|
||||||
yield self.build_chat_completion_chunk(_request_id, content=result)
|
yield self.build_chat_completion_chunk(
|
||||||
yield self.build_chat_completion_chunk(_request_id, finish_reason="stop")
|
_request_id, content=result, model=model_id_and_revision
|
||||||
|
)
|
||||||
|
yield self.build_chat_completion_chunk(_request_id, finish_reason="stop", model=model_id_and_revision)
|
||||||
|
|
||||||
thread.join()
|
thread.join()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -820,24 +928,20 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
`Generator[str, None, None]`: A generator that yields the OpenAI Response events.
|
`Generator[str, None, None]`: A generator that yields the OpenAI Response events.
|
||||||
"""
|
"""
|
||||||
# TODO -- Implement non-streaming mode
|
# TODO -- Implement non-streaming mode
|
||||||
if self.args.force_model is not None:
|
model_id_and_revision = self.process_model_name(req["model"])
|
||||||
req["model"] = self.args.force_model
|
must_discard_cache = model_id_and_revision != self.last_text_model
|
||||||
|
self.last_text_model = model_id_and_revision
|
||||||
|
model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision)
|
||||||
|
|
||||||
update_model = self.canonicalized_model_name(req["model"]) != self.loaded_model
|
text = tokenizer.apply_chat_template(req["input"], add_generation_prompt=True, tokenize=False)
|
||||||
if update_model:
|
inputs = tokenizer(text, return_tensors="pt").to(model.device)["input_ids"]
|
||||||
self.load_model_and_tokenizer(req["model"], self.args)
|
|
||||||
|
|
||||||
text = self.tokenizer.apply_chat_template(req["input"], add_generation_prompt=True, tokenize=False)
|
|
||||||
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
|
|
||||||
request_id = req.get("previous_response_id", "req_0")
|
request_id = req.get("previous_response_id", "req_0")
|
||||||
|
|
||||||
generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
|
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
||||||
generation_config = create_generation_config_from_req(
|
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
|
||||||
req, model_generation_config=self.model.generation_config
|
|
||||||
)
|
|
||||||
|
|
||||||
last_kv_cache = None
|
last_kv_cache = None
|
||||||
if self.is_continuation(req) and not update_model:
|
if self.is_continuation(req) and not must_discard_cache:
|
||||||
last_kv_cache = self.last_kv_cache
|
last_kv_cache = self.last_kv_cache
|
||||||
|
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
@@ -850,7 +954,12 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def stream_response(streamer, _request_id):
|
def stream_response(streamer, _request_id):
|
||||||
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
|
# Thin wrapper to save the KV cache after generation
|
||||||
|
def generate_with_cache(**kwargs):
|
||||||
|
generate_output = model.generate(**kwargs)
|
||||||
|
self.last_kv_cache = generate_output.past_key_values
|
||||||
|
|
||||||
|
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
|
||||||
sequence_number = 0
|
sequence_number = 0
|
||||||
output_index = 0
|
output_index = 0
|
||||||
content_index = 0
|
content_index = 0
|
||||||
@@ -868,7 +977,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
id=f"resp_{request_id}",
|
id=f"resp_{request_id}",
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
status="queued",
|
status="queued",
|
||||||
model=self.loaded_model,
|
model=model_id_and_revision,
|
||||||
instructions=req.get("instructions"),
|
instructions=req.get("instructions"),
|
||||||
text={"format": {"type": "text"}},
|
text={"format": {"type": "text"}},
|
||||||
object="response",
|
object="response",
|
||||||
@@ -889,7 +998,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
id=f"resp_{request_id}",
|
id=f"resp_{request_id}",
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
status="in_progress",
|
status="in_progress",
|
||||||
model=self.loaded_model,
|
model=model_id_and_revision,
|
||||||
instructions=req.get("instructions"),
|
instructions=req.get("instructions"),
|
||||||
text={"format": {"type": "text"}},
|
text={"format": {"type": "text"}},
|
||||||
object="response",
|
object="response",
|
||||||
@@ -994,7 +1103,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
id=f"resp_{request_id}",
|
id=f"resp_{request_id}",
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
status="completed",
|
status="completed",
|
||||||
model=self.loaded_model,
|
model=model_id_and_revision,
|
||||||
instructions=req.get("instructions"),
|
instructions=req.get("instructions"),
|
||||||
text={"format": {"type": "text"}},
|
text={"format": {"type": "text"}},
|
||||||
output=[response_output_item_done.item],
|
output=[response_output_item_done.item],
|
||||||
@@ -1026,7 +1135,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
id=f"resp_{request_id}",
|
id=f"resp_{request_id}",
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
status="failed",
|
status="failed",
|
||||||
model=self.loaded_model,
|
model=model_id_and_revision,
|
||||||
instructions=req.get("instructions"),
|
instructions=req.get("instructions"),
|
||||||
text={"format": {"type": "text"}},
|
text={"format": {"type": "text"}},
|
||||||
output=[],
|
output=[],
|
||||||
@@ -1049,6 +1158,54 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
return stream_response(generation_streamer, request_id)
|
return stream_response(generation_streamer, request_id)
|
||||||
|
|
||||||
|
def generate_transcription(self, req: dict) -> Generator[str, None, None]:
|
||||||
|
"""
|
||||||
|
Generates an OpenAI Transcription using the audio file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
req (`dict`): The request containing the audio file and model information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Generator[str, None, None]`: A generator that yields the transcription result.
|
||||||
|
"""
|
||||||
|
# TODO: implement streaming transcription (currently, it's not streaming)
|
||||||
|
if not is_librosa_available():
|
||||||
|
raise ImportError(
|
||||||
|
"Missing librosa dependency for audio transcription. Please install with `pip install librosa`"
|
||||||
|
)
|
||||||
|
model_id_and_revision = self.process_model_name(req["model"])
|
||||||
|
audio_model, audio_processor = self.load_audio_model_and_processor(model_id_and_revision)
|
||||||
|
|
||||||
|
generation_streamer = TextIteratorStreamer(
|
||||||
|
audio_processor.tokenizer, skip_special_tokens=True, skip_prompt=True
|
||||||
|
)
|
||||||
|
generation_config = create_generation_config_from_req(
|
||||||
|
req, model_generation_config=audio_model.generation_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read the binary audio file using librosa
|
||||||
|
model_sampling_rate = audio_processor.feature_extractor.sampling_rate
|
||||||
|
audio_bytes = io.BytesIO(req["file"])
|
||||||
|
audio_array, _ = librosa.load(audio_bytes, sr=model_sampling_rate, mono=True)
|
||||||
|
audio_inputs = audio_processor(audio_array, sampling_rate=model_sampling_rate, return_tensors="pt").to(
|
||||||
|
audio_model.device
|
||||||
|
)
|
||||||
|
audio_inputs["input_features"] = audio_inputs["input_features"].to(audio_model.dtype)
|
||||||
|
|
||||||
|
generation_kwargs = {
|
||||||
|
"streamer": generation_streamer,
|
||||||
|
"generation_config": generation_config,
|
||||||
|
"return_dict_in_generate": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate_transcription():
|
||||||
|
generated_ids = audio_model.generate(**audio_inputs, **generation_kwargs)
|
||||||
|
transcription_text = audio_processor.batch_decode(generated_ids.sequences, skip_special_tokens=True)[0]
|
||||||
|
transcription = Transcription(text=transcription_text)
|
||||||
|
yield f"{transcription.model_dump_json(exclude_none=True)}"
|
||||||
|
|
||||||
|
return _generate_transcription()
|
||||||
|
|
||||||
def is_continuation(self, req: dict) -> bool:
|
def is_continuation(self, req: dict) -> bool:
|
||||||
"""
|
"""
|
||||||
Determines whether the current request is a continuation of the last request. In other words, if it is the
|
Determines whether the current request is a continuation of the last request. In other words, if it is the
|
||||||
@@ -1108,39 +1265,49 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
return quantization_config
|
return quantization_config
|
||||||
|
|
||||||
def canonicalized_model_name(self, model_id: str) -> str:
|
def process_model_name(self, model_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Canonicalizes the model name to the format "model_id@revision". If the model_id DOESN'T contain an @, it
|
Applies the `force_model` CLI argument and canonicalizes the model name to the format "model_id@revision".
|
||||||
defaults to "model_id@main".
|
If the model_id DOESN'T contain an @, it defaults to "model_id@main".
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id (`str`): The model ID.
|
model_id (`str`): The model ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`str`: The canonicalized model name.
|
`str`: The canonicalized model name to be used
|
||||||
"""
|
"""
|
||||||
|
if self.args.force_model is not None:
|
||||||
|
model_id = self.args.force_model
|
||||||
if "@" in model_id:
|
if "@" in model_id:
|
||||||
return model_id
|
return model_id
|
||||||
return f"{model_id}@main"
|
return f"{model_id}@main"
|
||||||
|
|
||||||
def load_model_and_tokenizer(self, model_id_and_revision: str, args: ServeArguments):
|
def _load_model_and_data_processor(
|
||||||
|
self, model_id_and_revision: str, model_cls: type[PreTrainedModel]
|
||||||
|
) -> tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]:
|
||||||
"""
|
"""
|
||||||
Loads the model and tokenizer from the given model ID and revision into the ServeCommand instance.
|
Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI
|
||||||
|
arguments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id_and_revision (`str`):
|
model_id_and_revision (`str`):
|
||||||
The model ID and revision to load.
|
The model ID and revision to load.
|
||||||
args (`ServeArguments`):
|
model_cls (`type[PreTrainedModel]`):
|
||||||
The serve arguments. May contain quantization settings, device, etc.
|
The model class to load.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]`: The loaded model and
|
||||||
|
data processor (tokenizer, audio processor, etc.).
|
||||||
"""
|
"""
|
||||||
logger.warning(f"Loading {model_id_and_revision}")
|
args = self.args
|
||||||
|
logger.info(f"Loading {model_id_and_revision}")
|
||||||
|
|
||||||
if "@" in model_id_and_revision:
|
if "@" in model_id_and_revision:
|
||||||
model_id, revision = model_id_and_revision.split("@", 1)
|
model_id, revision = model_id_and_revision.split("@", 1)
|
||||||
else:
|
else:
|
||||||
model_id, revision = model_id_and_revision, "main"
|
model_id, revision = model_id_and_revision, "main"
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
data_processor = AutoProcessor.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
trust_remote_code=args.trust_remote_code,
|
trust_remote_code=args.trust_remote_code,
|
||||||
@@ -1158,19 +1325,76 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
"trust_remote_code": args.trust_remote_code,
|
"trust_remote_code": args.trust_remote_code,
|
||||||
}
|
}
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
|
model = model_cls.from_pretrained(model_id, **model_kwargs)
|
||||||
|
|
||||||
if model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 1024:
|
|
||||||
model.generation_config.max_new_tokens = 1024
|
|
||||||
|
|
||||||
if getattr(model, "hf_device_map", None) is None:
|
if getattr(model, "hf_device_map", None) is None:
|
||||||
model = model.to(args.device)
|
model = model.to(args.device)
|
||||||
|
|
||||||
self.loaded_model = f"{model_id}@{revision}"
|
has_default_max_length = (
|
||||||
|
model.generation_config.max_new_tokens is None and model.generation_config.max_length == 20
|
||||||
|
)
|
||||||
|
has_short_max_new_tokens = (
|
||||||
|
model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 1024
|
||||||
|
)
|
||||||
|
if has_default_max_length or has_short_max_new_tokens:
|
||||||
|
model.generation_config.max_new_tokens = 1024
|
||||||
|
|
||||||
logger.warning(f"Loaded model {self.loaded_model}")
|
logger.info(f"Loaded model {model_id_and_revision}")
|
||||||
self.model = model
|
return model, data_processor
|
||||||
self.tokenizer = tokenizer
|
|
||||||
|
def load_text_model_and_tokenizer(
|
||||||
|
self, model_id_and_revision: str
|
||||||
|
) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
|
||||||
|
"""
|
||||||
|
Loads the text model and tokenizer from the given model ID and revision into the ServeCommand instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id_and_revision (`str`):
|
||||||
|
The model ID and revision to load.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and tokenizer.
|
||||||
|
"""
|
||||||
|
if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
|
||||||
|
model, tokenizer = self._load_model_and_data_processor(model_id_and_revision, AutoModelForCausalLM)
|
||||||
|
self.loaded_models[model_id_and_revision] = TimedModel(
|
||||||
|
model,
|
||||||
|
timeout_seconds=self.args.model_timeout,
|
||||||
|
processor=tokenizer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.loaded_models[model_id_and_revision].reset_timer()
|
||||||
|
model = self.loaded_models[model_id_and_revision].model
|
||||||
|
tokenizer = self.loaded_models[model_id_and_revision].processor
|
||||||
|
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, ProcessorMixin]:
|
||||||
|
"""
|
||||||
|
Loads the audio model and processor from the given model ID and revision into the ServeCommand instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id_and_revision (`str`):
|
||||||
|
The model ID and revision to load.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor.
|
||||||
|
"""
|
||||||
|
if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
|
||||||
|
audio_model, audio_processor = self._load_model_and_data_processor(
|
||||||
|
model_id_and_revision, AutoModelForSpeechSeq2Seq
|
||||||
|
)
|
||||||
|
self.loaded_models[model_id_and_revision] = TimedModel(
|
||||||
|
audio_model,
|
||||||
|
timeout_seconds=self.args.model_timeout,
|
||||||
|
processor=audio_processor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.loaded_models[model_id_and_revision].reset_timer()
|
||||||
|
audio_model = self.loaded_models[model_id_and_revision].model
|
||||||
|
audio_processor = self.loaded_models[model_id_and_revision].processor
|
||||||
|
|
||||||
|
return audio_model, audio_processor
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -63,14 +63,13 @@ class ServeCLITest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
dummy = ServeCommand.__new__(ServeCommand)
|
dummy = ServeCommand.__new__(ServeCommand)
|
||||||
dummy.args = type("Args", (), {})()
|
dummy.args = type("Args", (), {})()
|
||||||
dummy.loaded_model = "dummy_model@main"
|
|
||||||
|
|
||||||
# The keys for these fields must be present in every chunk
|
# The keys for these fields must be present in every chunk
|
||||||
MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"]
|
MANDATORY_FIELDS = ["data", "id", "choices", "created", "model", "object", "system_fingerprint"]
|
||||||
|
|
||||||
# Case 1: most fields are provided
|
# Case 1: most fields are provided
|
||||||
chunk = ServeCommand.build_chat_completion_chunk(
|
chunk = ServeCommand.build_chat_completion_chunk(
|
||||||
dummy, request_id="req0", content="hello", finish_reason="stop", role="user"
|
dummy, request_id="req0", content="hello", finish_reason="stop", role="user", model="dummy_model@main"
|
||||||
)
|
)
|
||||||
for field in MANDATORY_FIELDS:
|
for field in MANDATORY_FIELDS:
|
||||||
self.assertIn(field, chunk)
|
self.assertIn(field, chunk)
|
||||||
@@ -79,13 +78,13 @@ class ServeCLITest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Case 2: only the role is provided -- other fields in 'choices' are omitted
|
# Case 2: only the role is provided -- other fields in 'choices' are omitted
|
||||||
chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user")
|
chunk = dummy.build_chat_completion_chunk(request_id="req0", role="user", model="dummy_model@main")
|
||||||
for field in MANDATORY_FIELDS:
|
for field in MANDATORY_FIELDS:
|
||||||
self.assertIn(field, chunk)
|
self.assertIn(field, chunk)
|
||||||
self.assertIn('"choices":[{"delta":{"role":"user"},"index":0}]', chunk)
|
self.assertIn('"choices":[{"delta":{"role":"user"},"index":0}]', chunk)
|
||||||
|
|
||||||
# Case 3: only the content is provided -- other fields in 'choices' are omitted
|
# Case 3: only the content is provided -- other fields in 'choices' are omitted
|
||||||
chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello")
|
chunk = dummy.build_chat_completion_chunk(request_id="req0", content="hello", model="dummy_model@main")
|
||||||
for field in MANDATORY_FIELDS:
|
for field in MANDATORY_FIELDS:
|
||||||
self.assertIn(field, chunk)
|
self.assertIn(field, chunk)
|
||||||
self.assertIn('"choices":[{"delta":{"content":"hello"},"index":0}]', chunk)
|
self.assertIn('"choices":[{"delta":{"content":"hello"},"index":0}]', chunk)
|
||||||
@@ -96,7 +95,7 @@ class ServeCLITest(unittest.TestCase):
|
|||||||
function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'),
|
function=ChoiceDeltaToolCallFunction(name="foo_bar", arguments='{"foo1": "bar1", "foo2": "bar2"}'),
|
||||||
type="function",
|
type="function",
|
||||||
)
|
)
|
||||||
chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call])
|
chunk = dummy.build_chat_completion_chunk(request_id="req0", tool_calls=[tool_call], model="dummy_model@main")
|
||||||
for field in MANDATORY_FIELDS:
|
for field in MANDATORY_FIELDS:
|
||||||
self.assertIn(field, chunk)
|
self.assertIn(field, chunk)
|
||||||
expected_choices_content = (
|
expected_choices_content = (
|
||||||
|
|||||||
Reference in New Issue
Block a user