Make forward asynchrone to avoid long computation timing out.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
Morgan Funtowicz
2020-01-10 20:59:04 +01:00
committed by Lysandre Debut
parent 6e6c8c52ed
commit 908cd5ea27
2 changed files with 52 additions and 19 deletions

View File

@@ -1,6 +1,8 @@
import logging import logging
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional, Union from typing import Any, List, Optional
from starlette.responses import JSONResponse
from transformers import Pipeline from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
@@ -10,6 +12,7 @@ from transformers.pipelines import SUPPORTED_TASKS, pipeline
try: try:
from uvicorn import run from uvicorn import run
from fastapi import FastAPI, HTTPException, Body from fastapi import FastAPI, HTTPException, Body
from fastapi.routing import APIRoute
from pydantic import BaseModel from pydantic import BaseModel
_serve_dependancies_installed = True _serve_dependancies_installed = True
@@ -37,7 +40,7 @@ def serve_command_factory(args: Namespace):
tokenizer=args.tokenizer, tokenizer=args.tokenizer,
device=args.device, device=args.device,
) )
return ServeCommand(nlp, args.host, args.port) return ServeCommand(nlp, args.host, args.port, args.workers)
class ServeModelInfoResult(BaseModel): class ServeModelInfoResult(BaseModel):
@@ -89,6 +92,7 @@ class ServeCommand(BaseTransformersCLICommand):
) )
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.") serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.") serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.") serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.") serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.") serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
@@ -100,12 +104,14 @@ class ServeCommand(BaseTransformersCLICommand):
) )
serve_parser.set_defaults(func=serve_command_factory) serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int): def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
self._pipeline = pipeline self._pipeline = pipeline
self._host = host self.host = host
self._port = port self.port = port
self.workers = workers
if not _serve_dependancies_installed: if not _serve_dependancies_installed:
raise RuntimeError( raise RuntimeError(
"Using serve command requires FastAPI and unicorn. " "Using serve command requires FastAPI and unicorn. "
@@ -114,18 +120,42 @@ class ServeCommand(BaseTransformersCLICommand):
) )
else: else:
logger.info("Serving model over {}:{}".format(host, port)) logger.info("Serving model over {}:{}".format(host, port))
self._app = FastAPI() self._app = FastAPI(
routes=[
# Register routes APIRoute(
self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"]) "/",
self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"]) self.model_info,
self._app.add_api_route( response_model=ServeModelInfoResult,
"/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"] response_class=JSONResponse,
methods=["GET"],
),
APIRoute(
"/tokenize",
self.tokenize,
response_model=ServeTokenizeResult,
response_class=JSONResponse,
methods=["POST"],
),
APIRoute(
"/detokenize",
self.detokenize,
response_model=ServeDeTokenizeResult,
response_class=JSONResponse,
methods=["POST"],
),
APIRoute(
"/forward",
self.forward,
response_model=ServeForwardResult,
response_class=JSONResponse,
methods=["POST"],
),
],
timeout=600,
) )
self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
def run(self): def run(self):
run(self._app, host=self._host, port=self._port) run(self._app, host=self.host, port=self.port, workers=self.workers)
def model_info(self): def model_info(self):
return ServeModelInfoResult(infos=vars(self._pipeline.model.config)) return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
@@ -166,7 +196,7 @@ class ServeCommand(BaseTransformersCLICommand):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)): async def forward(self, inputs=Body(None, embed=True)):
""" """
**inputs**: **inputs**:
**attention_mask**: **attention_mask**:

View File

@@ -28,8 +28,9 @@ from . import __version__
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
try: try:
if os.environ.get("USE_TORCH", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
os.environ.get("USE_TF", 'AUTO').upper() not in ("1", "ON", "YES"): USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
import torch import torch
_torch_available = True # pylint: disable=invalid-name _torch_available = True # pylint: disable=invalid-name
@@ -41,8 +42,10 @@ except ImportError:
_torch_available = False # pylint: disable=invalid-name _torch_available = False # pylint: disable=invalid-name
try: try:
if os.environ.get("USE_TF", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
os.environ.get("USE_TORCH", 'AUTO').upper() not in ("1", "ON", "YES"): USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
import tensorflow as tf import tensorflow as tf
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2 assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2