Transformers serve VLM (#39454)
* Add support for VLMs in Transformers Serve * Raushan comments * Update src/transformers/commands/serving.py Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com> * Quick fix * CPU -> Auto * Update src/transformers/commands/serving.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Fixup --------- Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -11,22 +11,34 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import base64
|
||||||
import copy
|
import copy
|
||||||
|
import datetime
|
||||||
|
import enum
|
||||||
import functools
|
import functools
|
||||||
import gc
|
import gc
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from io import BytesIO
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Generator, Optional, Union
|
from typing import Generator, Iterable, Optional, Union
|
||||||
|
|
||||||
from huggingface_hub import ModelInfo, model_info
|
from huggingface_hub import model_info
|
||||||
|
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||||
|
from openai.types.chat import ChatCompletionMessageParam
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers.models.auto.modeling_auto import (
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||||
|
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
|
||||||
|
)
|
||||||
from transformers.utils.import_utils import (
|
from transformers.utils.import_utils import (
|
||||||
is_fastapi_available,
|
is_fastapi_available,
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
@@ -35,7 +47,13 @@ from transformers.utils.import_utils import (
|
|||||||
is_uvicorn_available,
|
is_uvicorn_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .. import LogitsProcessorList, PreTrainedTokenizerFast, ProcessorMixin, TextIteratorStreamer
|
from .. import (
|
||||||
|
AutoConfig,
|
||||||
|
LogitsProcessorList,
|
||||||
|
PreTrainedTokenizerFast,
|
||||||
|
ProcessorMixin,
|
||||||
|
TextIteratorStreamer,
|
||||||
|
)
|
||||||
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
|
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
|
||||||
from ..utils import is_torch_available, logging
|
from ..utils import is_torch_available, logging
|
||||||
from . import BaseTransformersCLICommand
|
from . import BaseTransformersCLICommand
|
||||||
@@ -45,8 +63,6 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
|
||||||
AutoModelForSpeechSeq2Seq,
|
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
GenerationConfig,
|
GenerationConfig,
|
||||||
@@ -187,6 +203,13 @@ _TOOL_CALL_TOKENS = {
|
|||||||
_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())
|
_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())
|
||||||
|
|
||||||
|
|
||||||
|
class Modality(enum.Enum):
    """
    Closed set of model kinds the serving code can tell apart.

    Every member's value is the same upper-case acronym as its name, so `Modality("VLM") is Modality.VLM`.
    """

    # Text-only causal language model.
    LLM = "LLM"
    # Image-text-to-text (vision language) model.
    VLM = "VLM"
    # Presumably speech-to-text -- not used by any code visible here; confirm against callers.
    STT = "STT"
    # Presumably text-to-speech -- not used by any code visible here; confirm against callers.
    TTS = "TTS"
|
||||||
|
|
||||||
|
|
||||||
def serve_command_factory(args: Namespace):
|
def serve_command_factory(args: Namespace):
|
||||||
"""
|
"""
|
||||||
Factory function used to instantiate serving server from provided command line arguments.
|
Factory function used to instantiate serving server from provided command line arguments.
|
||||||
@@ -271,7 +294,7 @@ class ToolState:
|
|||||||
|
|
||||||
class TimedModel:
|
class TimedModel:
|
||||||
"""
|
"""
|
||||||
A class that holds a PreTrainedModel instance and its associated processor (tokenizer, audio processor, etc.).
|
A class that holds a PreTrainedModel instance and its associated processor.
|
||||||
Automatically deletes the instances after a specified timeout.
|
Automatically deletes the instances after a specified timeout.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -325,7 +348,13 @@ class ServeArguments:
|
|||||||
`transformers serve --help`
|
`transformers serve --help`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
|
device: str = field(
|
||||||
|
default="auto",
|
||||||
|
metadata={
|
||||||
|
"help": "Device to use for inference; will default to `auto` and"
|
||||||
|
"place the model on an accelerator if available."
|
||||||
|
},
|
||||||
|
)
|
||||||
torch_dtype: Optional[str] = field(
|
torch_dtype: Optional[str] = field(
|
||||||
default="auto",
|
default="auto",
|
||||||
metadata={
|
metadata={
|
||||||
@@ -438,7 +467,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
# cache and avoid re-running prefill
|
# cache and avoid re-running prefill
|
||||||
self.last_messages = None
|
self.last_messages = None
|
||||||
self.last_kv_cache = None
|
self.last_kv_cache = None
|
||||||
self.last_text_model = None
|
self.last_model = None
|
||||||
|
|
||||||
def _validate_request(
|
def _validate_request(
|
||||||
self,
|
self,
|
||||||
@@ -632,27 +661,15 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
output = self.generate_transcription(parsed_request)
|
output = self.generate_transcription(parsed_request)
|
||||||
return StreamingResponse(output, media_type="text/event-stream")
|
return StreamingResponse(output, media_type="text/event-stream")
|
||||||
|
|
||||||
|
@app.options("/v1/models")
|
||||||
@app.get("/v1/models")
|
@app.get("/v1/models")
|
||||||
def get_all_models():
|
def get_all_models():
|
||||||
return JSONResponse(
|
return JSONResponse({"object": "list", "data": self.get_gen_models()})
|
||||||
{
|
|
||||||
"object": "list",
|
|
||||||
"data": [
|
|
||||||
{
|
|
||||||
"id": model.id,
|
|
||||||
"object": "model",
|
|
||||||
"created": model.created_at.timestamp(),
|
|
||||||
"owned_by": model.author,
|
|
||||||
}
|
|
||||||
for model in self.get_text_gen_models()
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
|
uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
@functools.lru_cache(maxsize=None)
|
||||||
def get_text_gen_models(self) -> list[ModelInfo]:
|
def get_gen_models(self) -> list[dict[str, any]]:
|
||||||
"""
|
"""
|
||||||
This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
|
This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
|
||||||
model working with generate can work.
|
model working with generate can work.
|
||||||
@@ -660,16 +677,42 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
|
This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
|
||||||
integrations.
|
integrations.
|
||||||
"""
|
"""
|
||||||
|
models = [
|
||||||
|
"Menlo/Jan-nano",
|
||||||
|
"Menlo/Jan-nano-128k",
|
||||||
|
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||||
|
"Qwen/Qwen2.5-3B-Instruct",
|
||||||
|
"Qwen/Qwen2.5-7B-Instruct",
|
||||||
|
"Qwen/Qwen2.5-14B-Instruct",
|
||||||
|
"meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"meta-llama/Llama-3.2-1B-Instruct",
|
||||||
|
"meta-llama/Llama-3.3-70B-Instruct",
|
||||||
|
"HuggingFaceTB/SmolVLM-Instruct",
|
||||||
|
"ibm-granite/granite-vision-3.2-2b",
|
||||||
|
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||||
|
"OpenGVLab/InternVL3-1B",
|
||||||
|
]
|
||||||
|
|
||||||
|
if HF_HUB_OFFLINE:
|
||||||
return [
|
return [
|
||||||
model_info("Menlo/Jan-nano"),
|
{
|
||||||
model_info("Menlo/Jan-nano-128k"),
|
"id": model,
|
||||||
model_info("Qwen/Qwen2.5-0.5B-Instruct"),
|
"object": "model",
|
||||||
model_info("Qwen/Qwen2.5-3B-Instruct"),
|
"created": datetime.datetime.now().timestamp(),
|
||||||
model_info("Qwen/Qwen2.5-7B-Instruct"),
|
"owned_by": model.split("/")[0],
|
||||||
model_info("Qwen/Qwen2.5-14B-Instruct"),
|
}
|
||||||
model_info("meta-llama/Llama-3.1-8B-Instruct"),
|
for model in models
|
||||||
model_info("meta-llama/Llama-3.2-1B-Instruct"),
|
]
|
||||||
model_info("meta-llama/Llama-3.3-70B-Instruct"),
|
else:
|
||||||
|
model_infos = [model_info(model) for model in models]
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": model.id,
|
||||||
|
"object": "model",
|
||||||
|
"created": model.created_at.timestamp(),
|
||||||
|
"owned_by": model.author,
|
||||||
|
}
|
||||||
|
for model in model_infos
|
||||||
]
|
]
|
||||||
|
|
||||||
def continuous_batching_chat_completion(self, req: dict) -> Generator[str, None, None]:
|
def continuous_batching_chat_completion(self, req: dict) -> Generator[str, None, None]:
|
||||||
@@ -684,14 +727,16 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model_id_and_revision = self.process_model_name(req["model"])
|
model_id_and_revision = self.process_model_name(req["model"])
|
||||||
must_discard_cache = model_id_and_revision != self.last_text_model
|
must_discard_cache = model_id_and_revision != self.last_model
|
||||||
self.last_text_model = model_id_and_revision
|
self.last_model = model_id_and_revision
|
||||||
if must_discard_cache:
|
if must_discard_cache:
|
||||||
# When switching models, terminate a continuous batching manager if it is running.
|
# When switching models, terminate a continuous batching manager if it is running.
|
||||||
if self.running_continuous_batching_manager is not None:
|
if self.running_continuous_batching_manager is not None:
|
||||||
self.running_continuous_batching_manager.stop(block=True, timeout=2)
|
self.running_continuous_batching_manager.stop(block=True, timeout=2)
|
||||||
self.running_continuous_batching_manager = None
|
self.running_continuous_batching_manager = None
|
||||||
model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision)
|
model, processor = self.load_model_and_processor(model_id_and_revision)
|
||||||
|
|
||||||
|
tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
||||||
|
|
||||||
generation_config = create_generation_config_from_req(
|
generation_config = create_generation_config_from_req(
|
||||||
req,
|
req,
|
||||||
@@ -717,7 +762,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
self.running_continuous_batching_manager.start()
|
self.running_continuous_batching_manager.start()
|
||||||
|
|
||||||
# TODO (Joao, Lysandre): this should also work with tool support
|
# TODO (Joao, Lysandre): this should also work with tool support
|
||||||
inputs = tokenizer.apply_chat_template(req["messages"], return_tensors="pt", add_generation_prompt=True).to(
|
inputs = processor.apply_chat_template(req["messages"], return_tensors="pt", add_generation_prompt=True).to(
|
||||||
model.device
|
model.device
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -759,6 +804,50 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
return stream_chat_completion(inputs[0])
|
return stream_chat_completion(inputs[0])
|
||||||
|
|
||||||
|
@staticmethod
def get_model_modality(model: PreTrainedModel) -> Modality:
    """
    Classifies a loaded model as a VLM or an LLM based on the name of its class.

    Args:
        model (`PreTrainedModel`):
            The instantiated model to classify.

    Returns:
        `Modality`: `Modality.VLM` if the class name appears among the image-text-to-text mapping values,
        otherwise `Modality.LLM` if it appears among the causal LM mapping values.

    Raises:
        `ValueError`: If the class name appears in neither mapping.
    """
    class_name = model.__class__.__name__

    # The image-text-to-text mapping is consulted first, so it takes precedence for a class listed in both.
    if class_name in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
        return Modality.VLM

    if class_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
        return Modality.LLM

    raise ValueError(f"Unknown modality: {class_name}")
|
||||||
|
|
||||||
|
@staticmethod
def get_processor_inputs_from_inbound_messages(messages, modality: Modality):
    """
    Converts inbound OpenAI-style chat messages into the message format passed to `apply_chat_template`.

    Args:
        messages:
            Iterable of message dictionaries, each with a "role" key and a "content" key. "content" is either a
            string or a list of content parts, where each part is a dictionary with a "type" key ("text" or
            "image_url").
        modality (`Modality`):
            Modality of the target model. Only `Modality.LLM` and `Modality.VLM` are handled; for any other value
            each message is emitted with an empty "content" list.

    Returns:
        `list[dict]`: One new dictionary per inbound message (the inbound dictionaries are not modified).
        - For `Modality.LLM`, "content" is a single string. Text parts of a list content are joined with a space;
          non-text parts are dropped.
        - For `Modality.VLM`, "content" is a list of `{"type": "text", ...}` and `{"type": "image", "url": ...}`
          parts. Parts of any other type are dropped.

    Base64-encoded images are decoded and written to a temporary PNG file that is NOT deleted automatically
    (`delete=False`); removing it is up to the caller.
    """
    processor_inputs = []

    for message in messages:
        content = message["content"]
        parsed_message = {"role": message["role"], "content": []}

        if modality == Modality.LLM:
            if isinstance(content, str):
                parsed_message["content"] = content
            elif isinstance(content, dict):
                # Kept for backward compatibility: a single part given as a bare dictionary.
                parsed_message["content"] = content["text"]
            else:
                # OpenAI clients may send text as a list of parts; indexing that list with "text" would raise.
                parsed_message["content"] = " ".join(part["text"] for part in content if part["type"] == "text")

        elif modality == Modality.VLM:
            if isinstance(content, str):
                parsed_message["content"].append({"type": "text", "text": content})
            else:
                for part in content:
                    if part["type"] == "text":
                        parsed_message["content"].append(part)
                    elif part["type"] == "image_url":
                        url = part["image_url"]["url"]

                        # Remote URLs are forwarded untouched; anything else is treated as base64 image data,
                        # with or without the `data:image/...;base64,` prefix.
                        if not url.startswith(("http://", "https://")):
                            image_data = re.sub("^data:image/.+;base64,", "", url)
                            image = Image.open(BytesIO(base64.b64decode(image_data)))

                            # Only the unique path is needed: close the handle right away so no file descriptor
                            # is leaked, then let PIL write the PNG by name.
                            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as file:
                                url = file.name
                            image.save(url)

                        parsed_message["content"].append({"type": "image", "url": url})

        processor_inputs.append(parsed_message)

    return processor_inputs
|
||||||
|
|
||||||
def generate_chat_completion(self, req: dict) -> Generator[str, None, None]:
|
def generate_chat_completion(self, req: dict) -> Generator[str, None, None]:
|
||||||
"""
|
"""
|
||||||
Generates an OpenAI Chat Completion using `generate`.
|
Generates an OpenAI Chat Completion using `generate`.
|
||||||
@@ -769,15 +858,24 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
Returns:
|
Returns:
|
||||||
`Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
|
`Generator[str, None, None]`: A generator that yields the OpenAI Chat Completion chunks.
|
||||||
"""
|
"""
|
||||||
|
if self.args.force_model is not None:
|
||||||
|
req["model"] = self.args.force_model
|
||||||
|
|
||||||
|
messages: Iterable[ChatCompletionMessageParam] = req["messages"]
|
||||||
|
|
||||||
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
|
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
|
||||||
# request whose last message is from the assistant.
|
# request whose last message is from the assistant.
|
||||||
if req["messages"][-1]["role"] == "assistant":
|
if messages[-1]["role"] == "assistant":
|
||||||
return
|
return
|
||||||
|
|
||||||
model_id_and_revision = self.process_model_name(req["model"])
|
model_id_and_revision = self.process_model_name(req["model"])
|
||||||
must_discard_cache = model_id_and_revision != self.last_text_model
|
must_discard_cache = model_id_and_revision != self.last_model
|
||||||
self.last_text_model = model_id_and_revision
|
|
||||||
model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision)
|
self.last_model = model_id_and_revision
|
||||||
|
model, processor = self.load_model_and_processor(model_id_and_revision)
|
||||||
|
|
||||||
|
modality = self.get_model_modality(model)
|
||||||
|
processor_inputs = self.get_processor_inputs_from_inbound_messages(messages, modality)
|
||||||
|
|
||||||
# ====== TOOL PREPROCESSING LOGIC ======
|
# ====== TOOL PREPROCESSING LOGIC ======
|
||||||
tool_model_family = None
|
tool_model_family = None
|
||||||
@@ -790,16 +888,18 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
# 2. force generation to pick from that tool's arguments
|
# 2. force generation to pick from that tool's arguments
|
||||||
# ====== END OF TOOL PREPROCESSING LOGIC ======
|
# ====== END OF TOOL PREPROCESSING LOGIC ======
|
||||||
|
|
||||||
if tool_model_family is not None:
|
inputs = processor.apply_chat_template(
|
||||||
text = tokenizer.apply_chat_template(
|
processor_inputs,
|
||||||
req["messages"], add_generation_prompt=True, tokenize=False, tools=req.get("tools")
|
add_generation_prompt=True,
|
||||||
|
tools=req.get("tools", None),
|
||||||
|
return_tensors="pt",
|
||||||
|
return_dict=True,
|
||||||
|
tokenize=True,
|
||||||
)
|
)
|
||||||
else:
|
inputs = inputs.to(model.device)
|
||||||
text = tokenizer.apply_chat_template(req["messages"], add_generation_prompt=True, tokenize=False)
|
|
||||||
inputs = tokenizer(text, return_tensors="pt").to(model.device)["input_ids"]
|
|
||||||
request_id = req.get("request_id", "req_0")
|
request_id = req.get("request_id", "req_0")
|
||||||
|
|
||||||
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
|
||||||
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
|
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
|
||||||
|
|
||||||
last_kv_cache = None
|
last_kv_cache = None
|
||||||
@@ -807,8 +907,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
last_kv_cache = self.last_kv_cache
|
last_kv_cache = self.last_kv_cache
|
||||||
|
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
"inputs": inputs,
|
**inputs,
|
||||||
"attention_mask": torch.ones_like(inputs),
|
|
||||||
"streamer": generation_streamer,
|
"streamer": generation_streamer,
|
||||||
"generation_config": generation_config,
|
"generation_config": generation_config,
|
||||||
"return_dict_in_generate": True,
|
"return_dict_in_generate": True,
|
||||||
@@ -929,15 +1028,14 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
"""
|
"""
|
||||||
# TODO -- Implement non-streaming mode
|
# TODO -- Implement non-streaming mode
|
||||||
model_id_and_revision = self.process_model_name(req["model"])
|
model_id_and_revision = self.process_model_name(req["model"])
|
||||||
must_discard_cache = model_id_and_revision != self.last_text_model
|
must_discard_cache = model_id_and_revision != self.last_model
|
||||||
self.last_text_model = model_id_and_revision
|
self.last_model = model_id_and_revision
|
||||||
model, tokenizer = self.load_text_model_and_tokenizer(model_id_and_revision)
|
model, processor = self.load_model_and_processor(model_id_and_revision)
|
||||||
|
|
||||||
text = tokenizer.apply_chat_template(req["input"], add_generation_prompt=True, tokenize=False)
|
inputs = processor.apply_chat_template(req["input"], add_generation_prompt=True).to(model.device)
|
||||||
inputs = tokenizer(text, return_tensors="pt").to(model.device)["input_ids"]
|
|
||||||
request_id = req.get("previous_response_id", "req_0")
|
request_id = req.get("previous_response_id", "req_0")
|
||||||
|
|
||||||
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
|
||||||
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
|
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
|
||||||
|
|
||||||
last_kv_cache = None
|
last_kv_cache = None
|
||||||
@@ -1282,9 +1380,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
return model_id
|
return model_id
|
||||||
return f"{model_id}@main"
|
return f"{model_id}@main"
|
||||||
|
|
||||||
def _load_model_and_data_processor(
|
def _load_model_and_data_processor(self, model_id_and_revision: str):
|
||||||
self, model_id_and_revision: str, model_cls: type[PreTrainedModel]
|
|
||||||
) -> tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]:
|
|
||||||
"""
|
"""
|
||||||
Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI
|
Generic method to load a model and a data processor from a model ID and revision, making use of the serve CLI
|
||||||
arguments.
|
arguments.
|
||||||
@@ -1325,7 +1421,9 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
"trust_remote_code": args.trust_remote_code,
|
"trust_remote_code": args.trust_remote_code,
|
||||||
}
|
}
|
||||||
|
|
||||||
model = model_cls.from_pretrained(model_id, **model_kwargs)
|
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
|
||||||
|
architecture = getattr(transformers, config.architectures[0])
|
||||||
|
model = architecture.from_pretrained(model_id, **model_kwargs)
|
||||||
|
|
||||||
if getattr(model, "hf_device_map", None) is None:
|
if getattr(model, "hf_device_map", None) is None:
|
||||||
model = model.to(args.device)
|
model = model.to(args.device)
|
||||||
@@ -1342,32 +1440,30 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
logger.info(f"Loaded model {model_id_and_revision}")
|
logger.info(f"Loaded model {model_id_and_revision}")
|
||||||
return model, data_processor
|
return model, data_processor
|
||||||
|
|
||||||
def load_text_model_and_tokenizer(
|
def load_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
|
||||||
self, model_id_and_revision: str
|
|
||||||
) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
|
|
||||||
"""
|
"""
|
||||||
Loads the text model and tokenizer from the given model ID and revision into the ServeCommand instance.
|
Loads the text model and processor from the given model ID and revision into the ServeCommand instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id_and_revision (`str`):
|
model_id_and_revision (`str`):
|
||||||
The model ID and revision to load.
|
The model ID and revision to load.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and tokenizer.
|
`tuple[PreTrainedModel, PreTrainedTokenizerFast]`: The loaded text model and processor.
|
||||||
"""
|
"""
|
||||||
if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
|
if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
|
||||||
model, tokenizer = self._load_model_and_data_processor(model_id_and_revision, AutoModelForCausalLM)
|
model, processor = self._load_model_and_data_processor(model_id_and_revision)
|
||||||
self.loaded_models[model_id_and_revision] = TimedModel(
|
self.loaded_models[model_id_and_revision] = TimedModel(
|
||||||
model,
|
model,
|
||||||
timeout_seconds=self.args.model_timeout,
|
timeout_seconds=self.args.model_timeout,
|
||||||
processor=tokenizer,
|
processor=processor,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.loaded_models[model_id_and_revision].reset_timer()
|
self.loaded_models[model_id_and_revision].reset_timer()
|
||||||
model = self.loaded_models[model_id_and_revision].model
|
model = self.loaded_models[model_id_and_revision].model
|
||||||
tokenizer = self.loaded_models[model_id_and_revision].processor
|
processor = self.loaded_models[model_id_and_revision].processor
|
||||||
|
|
||||||
return model, tokenizer
|
return model, processor
|
||||||
|
|
||||||
def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, ProcessorMixin]:
|
def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, ProcessorMixin]:
|
||||||
"""
|
"""
|
||||||
@@ -1381,9 +1477,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
`tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor.
|
`tuple[PreTrainedModel, ProcessorMixin]`: The loaded audio model and processor.
|
||||||
"""
|
"""
|
||||||
if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
|
if model_id_and_revision not in self.loaded_models or self.loaded_models[model_id_and_revision].is_deleted():
|
||||||
audio_model, audio_processor = self._load_model_and_data_processor(
|
audio_model, audio_processor = self._load_model_and_data_processor(model_id_and_revision)
|
||||||
model_id_and_revision, AutoModelForSpeechSeq2Seq
|
|
||||||
)
|
|
||||||
self.loaded_models[model_id_and_revision] = TimedModel(
|
self.loaded_models[model_id_and_revision] = TimedModel(
|
||||||
audio_model,
|
audio_model,
|
||||||
timeout_seconds=self.args.model_timeout,
|
timeout_seconds=self.args.model_timeout,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@@ -23,7 +24,7 @@ from parameterized import parameterized
|
|||||||
|
|
||||||
import transformers.commands.transformers_cli as cli
|
import transformers.commands.transformers_cli as cli
|
||||||
from transformers import GenerationConfig
|
from transformers import GenerationConfig
|
||||||
from transformers.commands.serving import ServeArguments, ServeCommand
|
from transformers.commands.serving import Modality, ServeArguments, ServeCommand
|
||||||
from transformers.testing_utils import CaptureStd, require_openai, slow
|
from transformers.testing_utils import CaptureStd, require_openai, slow
|
||||||
from transformers.utils.import_utils import is_openai_available
|
from transformers.utils.import_utils import is_openai_available
|
||||||
|
|
||||||
@@ -258,6 +259,104 @@ class ServeCompletionsMixin:
|
|||||||
# TODO: speed-based test to confirm that KV cache is working across requests
|
# TODO: speed-based test to confirm that KV cache is working across requests
|
||||||
|
|
||||||
|
|
||||||
|
class ServeCompletionsGenerateMockTests(unittest.TestCase):
    """
    Unit tests for `ServeCommand.get_processor_inputs_from_inbound_messages`. The method is static, so no server
    is started and no model is loaded.
    """

    def test_processor_inputs_from_inbound_messages_llm(self):
        """For LLMs, messages whose content is a plain string come out unchanged."""
        modality = Modality.LLM
        messages = [
            {"role": "user", "content": "How are you doing?"},
            {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"},
            {"role": "user", "content": "Can you help me write tests?"},
        ]
        # Snapshot taken before the call, from distinct dictionaries, so an in-place mutation of `messages`
        # cannot make the comparison pass by aliasing.
        expected_outputs = [dict(message) for message in messages]

        outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality)
        self.assertListEqual(expected_outputs, outputs)

    def test_processor_inputs_from_inbound_messages_vlm_text_only(self):
        """For VLMs, string content is wrapped into a single-element list of text parts."""
        modality = Modality.VLM
        messages = [
            {"role": "user", "content": "How are you doing?"},
            {"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"},
            {"role": "user", "content": "Can you help me write tests?"},
        ]

        expected_outputs = [
            {"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]},
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"}
                ],
            },
            {"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]},
        ]

        outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality)
        self.assertListEqual(expected_outputs, outputs)

    def test_processor_inputs_from_inbound_messages_vlm_text_and_image_in_base_64(self):
        """For VLMs, a base64 `image_url` part becomes an `image` part pointing at a file that exists on disk."""
        modality = Modality.VLM
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "How many pixels are in the image?"},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAASABIAAD/4QBARXhpZgAATU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAABaADAAQAAAABAAAABQAAAAD/7QA4UGhvdG9zaG9wIDMuMAA4QklNBAQAAAAAAAA4QklNBCUAAAAAABDUHYzZjwCyBOmACZjs+EJ+/8AAEQgABQAFAwEiAAIRAQMRAf/EAB8AAAEFAQEBAQEBAAAAAAAAAAABAgMEBQYHCAkKC//EALUQAAIBAwMCBAMFBQQEAAABfQECAwAEEQUSITFBBhNRYQcicRQygZGhCCNCscEVUtHwJDNicoIJChYXGBkaJSYnKCkqNDU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6g4SFhoeIiYqSk5SVlpeYmZqio6Slpqeoqaqys7S1tre4ubrCw8TFxsfIycrS09TV1tfY2drh4uPk5ebn6Onq8fLz9PX29/j5+v/EAB8BAAMBAQEBAQEBAQEAAAAAAAABAgMEBQYHCAkKC//EALURAAIBAgQEAwQHBQQEAAECdwABAgMRBAUhMQYSQVEHYXETIjKBCBRCkaGxwQkjM1LwFWJy0QoWJDThJfEXGBkaJicoKSo1Njc4OTpDREVGR0hJSlNUVVZXWFlaY2RlZmdoaWpzdHV2d3h5eoKDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uLj5OXm5+jp6vLz9PX29/j5+v/bAEMAAQEBAQEBAgEBAgICAgICAwICAgIDBAMDAwMDBAUEBAQEBAQFBQUFBQUFBQYGBgYGBgcHBwcHCAgICAgICAgICP/bAEMBAQEBAgICAwICAwgFBAUICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICP/dAAQAAf/aAAwDAQACEQMRAD8A/v4ooooA/9k="
                        },
                    },
                ],
            },
            {
                "role": "assistant",
                "content": "The number of pixels in the image cannot be determined from the provided information.",
            },
            {"role": "user", "content": "Alright"},
        ]

        expected_outputs = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "How many pixels are in the image?"},
                    # The path below is only a placeholder: the real temporary path differs on every run, so the
                    # test checks that the returned path exists rather than comparing it to this value.
                    {"type": "image", "url": "/var/folders/4v/64sxdhsd3gz3r8vhhnyc0mqw0000gn/T/tmp50oyghk6.png"},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": "The number of pixels in the image cannot be determined from the provided information.",
                    }
                ],
            },
            {"role": "user", "content": [{"type": "text", "text": "Alright"}]},
        ]

        outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality)

        # `zip` stops at the shorter sequence, so lengths are checked explicitly at every level.
        self.assertEqual(len(expected_outputs), len(outputs))

        for expected_output, output in zip(expected_outputs, outputs):
            self.assertEqual(expected_output["role"], output["role"])

            expected_output_content = expected_output["content"]
            output_content = output["content"]

            self.assertEqual(type(expected_output_content), type(output_content))

            if not isinstance(expected_output_content, list):
                self.fail("VLMs should only receive content as lists.")

            self.assertEqual(len(expected_output_content), len(output_content))

            for expected_output_content_item, output_content_item in zip(expected_output_content, output_content):
                self.assertIn("type", expected_output_content_item)
                self.assertIn("type", output_content_item)
                self.assertEqual(expected_output_content_item["type"], output_content_item["type"])

                if expected_output_content_item["type"] == "text":
                    self.assertEqual(expected_output_content_item["text"], output_content_item["text"])

                if expected_output_content_item["type"] == "image":
                    self.assertTrue(os.path.exists(output_content_item["url"]))
                    # The conversion leaves the decoded image on disk; remove it once the test is done.
                    self.addCleanup(os.remove, output_content_item["url"])
|
||||||
|
|
||||||
|
|
||||||
@slow # server startup time is slow on our push CI
|
@slow # server startup time is slow on our push CI
|
||||||
@require_openai
|
@require_openai
|
||||||
class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
|
class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user