Make forward asynchrone to avoid long computation timing out.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
committed by
Lysandre Debut
parent
6e6c8c52ed
commit
908cd5ea27
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from transformers import Pipeline
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
@@ -10,6 +12,7 @@ from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||
try:
|
||||
from uvicorn import run
|
||||
from fastapi import FastAPI, HTTPException, Body
|
||||
from fastapi.routing import APIRoute
|
||||
from pydantic import BaseModel
|
||||
|
||||
_serve_dependancies_installed = True
|
||||
@@ -37,7 +40,7 @@ def serve_command_factory(args: Namespace):
|
||||
tokenizer=args.tokenizer,
|
||||
device=args.device,
|
||||
)
|
||||
return ServeCommand(nlp, args.host, args.port)
|
||||
return ServeCommand(nlp, args.host, args.port, args.workers)
|
||||
|
||||
|
||||
class ServeModelInfoResult(BaseModel):
|
||||
@@ -89,6 +92,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
)
|
||||
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
|
||||
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
|
||||
serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
|
||||
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
|
||||
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
|
||||
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
|
||||
@@ -100,12 +104,14 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
)
|
||||
serve_parser.set_defaults(func=serve_command_factory)
|
||||
|
||||
def __init__(self, pipeline: Pipeline, host: str, port: int):
|
||||
def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
|
||||
|
||||
self._pipeline = pipeline
|
||||
|
||||
self._host = host
|
||||
self._port = port
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.workers = workers
|
||||
|
||||
if not _serve_dependancies_installed:
|
||||
raise RuntimeError(
|
||||
"Using serve command requires FastAPI and unicorn. "
|
||||
@@ -114,18 +120,42 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
)
|
||||
else:
|
||||
logger.info("Serving model over {}:{}".format(host, port))
|
||||
self._app = FastAPI()
|
||||
|
||||
# Register routes
|
||||
self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
|
||||
self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
|
||||
self._app.add_api_route(
|
||||
"/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
|
||||
self._app = FastAPI(
|
||||
routes=[
|
||||
APIRoute(
|
||||
"/",
|
||||
self.model_info,
|
||||
response_model=ServeModelInfoResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["GET"],
|
||||
),
|
||||
APIRoute(
|
||||
"/tokenize",
|
||||
self.tokenize,
|
||||
response_model=ServeTokenizeResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["POST"],
|
||||
),
|
||||
APIRoute(
|
||||
"/detokenize",
|
||||
self.detokenize,
|
||||
response_model=ServeDeTokenizeResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["POST"],
|
||||
),
|
||||
APIRoute(
|
||||
"/forward",
|
||||
self.forward,
|
||||
response_model=ServeForwardResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["POST"],
|
||||
),
|
||||
],
|
||||
timeout=600,
|
||||
)
|
||||
self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
|
||||
|
||||
def run(self):
|
||||
run(self._app, host=self._host, port=self._port)
|
||||
run(self._app, host=self.host, port=self.port, workers=self.workers)
|
||||
|
||||
def model_info(self):
|
||||
return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
|
||||
@@ -166,7 +196,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
||||
|
||||
def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
|
||||
async def forward(self, inputs=Body(None, embed=True)):
|
||||
"""
|
||||
**inputs**:
|
||||
**attention_mask**:
|
||||
|
||||
Reference in New Issue
Block a user