Make forward asynchronous to avoid long computations timing out.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
Morgan Funtowicz
2020-01-10 20:59:04 +01:00
committed by Lysandre Debut
parent 6e6c8c52ed
commit 908cd5ea27
2 changed files with 52 additions and 19 deletions

View File

@@ -1,6 +1,8 @@
import logging
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional, Union
from typing import Any, List, Optional
from starlette.responses import JSONResponse
from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
@@ -10,6 +12,7 @@ from transformers.pipelines import SUPPORTED_TASKS, pipeline
try:
from uvicorn import run
from fastapi import FastAPI, HTTPException, Body
from fastapi.routing import APIRoute
from pydantic import BaseModel
_serve_dependancies_installed = True
@@ -37,7 +40,7 @@ def serve_command_factory(args: Namespace):
tokenizer=args.tokenizer,
device=args.device,
)
return ServeCommand(nlp, args.host, args.port)
return ServeCommand(nlp, args.host, args.port, args.workers)
class ServeModelInfoResult(BaseModel):
@@ -89,6 +92,7 @@ class ServeCommand(BaseTransformersCLICommand):
)
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
@@ -100,12 +104,14 @@ class ServeCommand(BaseTransformersCLICommand):
)
serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int):
def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
self._pipeline = pipeline
self._host = host
self._port = port
self.host = host
self.port = port
self.workers = workers
if not _serve_dependancies_installed:
raise RuntimeError(
"Using serve command requires FastAPI and unicorn. "
@@ -114,18 +120,42 @@ class ServeCommand(BaseTransformersCLICommand):
)
else:
logger.info("Serving model over {}:{}".format(host, port))
self._app = FastAPI()
# Register routes
self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
self._app.add_api_route(
"/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
self._app = FastAPI(
routes=[
APIRoute(
"/",
self.model_info,
response_model=ServeModelInfoResult,
response_class=JSONResponse,
methods=["GET"],
),
APIRoute(
"/tokenize",
self.tokenize,
response_model=ServeTokenizeResult,
response_class=JSONResponse,
methods=["POST"],
),
APIRoute(
"/detokenize",
self.detokenize,
response_model=ServeDeTokenizeResult,
response_class=JSONResponse,
methods=["POST"],
),
APIRoute(
"/forward",
self.forward,
response_model=ServeForwardResult,
response_class=JSONResponse,
methods=["POST"],
),
],
timeout=600,
)
self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])
def run(self):
run(self._app, host=self._host, port=self._port)
run(self._app, host=self.host, port=self.port, workers=self.workers)
def model_info(self):
    """
    Return the configuration attributes of the served pipeline's model,
    wrapped in a ServeModelInfoResult (handler registered on the "/" GET route).
    """
    model_config = self._pipeline.model.config
    return ServeModelInfoResult(infos=vars(model_config))
@@ -166,7 +196,7 @@ class ServeCommand(BaseTransformersCLICommand):
except Exception as e:
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
async def forward(self, inputs=Body(None, embed=True)):
"""
**inputs**:
**attention_mask**: