Make forward asynchrone to avoid long computation timing out.
Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
committed by
Lysandre Debut
parent
6e6c8c52ed
commit
908cd5ea27
@@ -1,6 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
from transformers import Pipeline
|
from transformers import Pipeline
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
@@ -10,6 +12,7 @@ from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
|||||||
try:
|
try:
|
||||||
from uvicorn import run
|
from uvicorn import run
|
||||||
from fastapi import FastAPI, HTTPException, Body
|
from fastapi import FastAPI, HTTPException, Body
|
||||||
|
from fastapi.routing import APIRoute
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
_serve_dependancies_installed = True
|
_serve_dependancies_installed = True
|
||||||
@@ -37,7 +40,7 @@ def serve_command_factory(args: Namespace):
|
|||||||
tokenizer=args.tokenizer,
|
tokenizer=args.tokenizer,
|
||||||
device=args.device,
|
device=args.device,
|
||||||
)
|
)
|
||||||
return ServeCommand(nlp, args.host, args.port)
|
return ServeCommand(nlp, args.host, args.port, args.workers)
|
||||||
|
|
||||||
|
|
||||||
class ServeModelInfoResult(BaseModel):
|
class ServeModelInfoResult(BaseModel):
|
||||||
@@ -89,6 +92,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
)
|
)
|
||||||
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
|
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
|
||||||
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
|
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
|
||||||
|
serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
|
||||||
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
|
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
|
||||||
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
|
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
|
||||||
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
|
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
|
||||||
@@ -100,12 +104,14 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
)
|
)
|
||||||
serve_parser.set_defaults(func=serve_command_factory)
|
serve_parser.set_defaults(func=serve_command_factory)
|
||||||
|
|
||||||
def __init__(self, pipeline: Pipeline, host: str, port: int):
|
def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
|
||||||
|
|
||||||
self._pipeline = pipeline
|
self._pipeline = pipeline
|
||||||
|
|
||||||
self._host = host
|
self.host = host
|
||||||
self._port = port
|
self.port = port
|
||||||
|
self.workers = workers
|
||||||
|
|
||||||
if not _serve_dependancies_installed:
|
if not _serve_dependancies_installed:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Using serve command requires FastAPI and unicorn. "
|
"Using serve command requires FastAPI and unicorn. "
|
||||||
@@ -114,18 +120,42 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Serving model over {}:{}".format(host, port))
|
logger.info("Serving model over {}:{}".format(host, port))
|
||||||
self._app = FastAPI()
|
self._app = FastAPI(
|
||||||
|
routes=[
|
||||||
# Register routes
|
APIRoute(
|
||||||
self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
|
"/",
|
||||||
self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
|
self.model_info,
|
||||||
self._app.add_api_route(
|
response_model=ServeModelInfoResult,
|
||||||
"/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
|
response_class=JSONResponse,
|
||||||
|
methods=["GET"],
|
||||||
|
),
|
||||||
|
APIRoute(
|
||||||
|
"/tokenize",
|
||||||
|
self.tokenize,
|
||||||
|
response_model=ServeTokenizeResult,
|
||||||
|
response_class=JSONResponse,
|
||||||
|
methods=["POST"],
|
||||||
|
),
|
||||||
|
APIRoute(
|
||||||
|
"/detokenize",
|
||||||
|
self.detokenize,
|
||||||
|
response_model=ServeDeTokenizeResult,
|
||||||
|
response_class=JSONResponse,
|
||||||
|
methods=["POST"],
|
||||||
|
),
|
||||||
|
APIRoute(
|
||||||
|
"/forward",
|
||||||
|
self.forward,
|
||||||
|
response_model=ServeForwardResult,
|
||||||
|
response_class=JSONResponse,
|
||||||
|
methods=["POST"],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
timeout=600,
|
||||||
)
|
)
|
||||||
self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
run(self._app, host=self._host, port=self._port)
|
run(self._app, host=self.host, port=self.port, workers=self.workers)
|
||||||
|
|
||||||
def model_info(self):
|
def model_info(self):
|
||||||
return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
|
return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
|
||||||
@@ -166,7 +196,7 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
||||||
|
|
||||||
def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
|
async def forward(self, inputs=Body(None, embed=True)):
|
||||||
"""
|
"""
|
||||||
**inputs**:
|
**inputs**:
|
||||||
**attention_mask**:
|
**attention_mask**:
|
||||||
|
|||||||
@@ -28,8 +28,9 @@ from . import __version__
|
|||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if os.environ.get("USE_TORCH", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
|
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||||
os.environ.get("USE_TF", 'AUTO').upper() not in ("1", "ON", "YES"):
|
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||||
|
if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
_torch_available = True # pylint: disable=invalid-name
|
_torch_available = True # pylint: disable=invalid-name
|
||||||
@@ -41,8 +42,10 @@ except ImportError:
|
|||||||
_torch_available = False # pylint: disable=invalid-name
|
_torch_available = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if os.environ.get("USE_TF", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \
|
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||||
os.environ.get("USE_TORCH", 'AUTO').upper() not in ("1", "ON", "YES"):
|
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||||
|
|
||||||
|
if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
|
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
|
||||||
|
|||||||
Reference in New Issue
Block a user