From 908cd5ea279fa57955c3e0e723a76950de134c9b Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Fri, 10 Jan 2020 20:59:04 +0100 Subject: [PATCH] Make forward asynchronous to avoid long computation timing out. Signed-off-by: Morgan Funtowicz --- src/transformers/commands/serving.py | 60 +++++++++++++++++++++------- src/transformers/file_utils.py | 11 +++-- 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index e05e4513fc..0d92f818c5 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -1,6 +1,8 @@ import logging from argparse import ArgumentParser, Namespace -from typing import Any, List, Optional, Union +from typing import Any, List, Optional + +from starlette.responses import JSONResponse from transformers import Pipeline from transformers.commands import BaseTransformersCLICommand @@ -10,6 +12,7 @@ from transformers.pipelines import SUPPORTED_TASKS, pipeline try: from uvicorn import run from fastapi import FastAPI, HTTPException, Body + from fastapi.routing import APIRoute from pydantic import BaseModel _serve_dependancies_installed = True @@ -37,7 +40,7 @@ def serve_command_factory(args: Namespace): tokenizer=args.tokenizer, device=args.device, ) - return ServeCommand(nlp, args.host, args.port) + return ServeCommand(nlp, args.host, args.port, args.workers) class ServeModelInfoResult(BaseModel): @@ -89,6 +92,7 @@ class ServeCommand(BaseTransformersCLICommand): ) serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.") serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.") + serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers") serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.") serve_parser.add_argument("--config", type=str, help="Model's config name or 
path to stored model.") serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.") @@ -100,12 +104,14 @@ class ServeCommand(BaseTransformersCLICommand): ) serve_parser.set_defaults(func=serve_command_factory) - def __init__(self, pipeline: Pipeline, host: str, port: int): + def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int): self._pipeline = pipeline - self._host = host - self._port = port + self.host = host + self.port = port + self.workers = workers + if not _serve_dependancies_installed: raise RuntimeError( "Using serve command requires FastAPI and unicorn. " @@ -114,18 +120,42 @@ class ServeCommand(BaseTransformersCLICommand): ) else: logger.info("Serving model over {}:{}".format(host, port)) - self._app = FastAPI() - - # Register routes - self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"]) - self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"]) - self._app.add_api_route( - "/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"] + self._app = FastAPI( + routes=[ + APIRoute( + "/", + self.model_info, + response_model=ServeModelInfoResult, + response_class=JSONResponse, + methods=["GET"], + ), + APIRoute( + "/tokenize", + self.tokenize, + response_model=ServeTokenizeResult, + response_class=JSONResponse, + methods=["POST"], + ), + APIRoute( + "/detokenize", + self.detokenize, + response_model=ServeDeTokenizeResult, + response_class=JSONResponse, + methods=["POST"], + ), + APIRoute( + "/forward", + self.forward, + response_model=ServeForwardResult, + response_class=JSONResponse, + methods=["POST"], + ), + ], + timeout=600, ) - self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"]) def run(self): - run(self._app, host=self._host, port=self._port) + run(self._app, host=self.host, port=self.port, workers=self.workers) def model_info(self): 
return ServeModelInfoResult(infos=vars(self._pipeline.model.config)) @@ -166,7 +196,7 @@ class ServeCommand(BaseTransformersCLICommand): except Exception as e: raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) - def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)): + async def forward(self, inputs=Body(None, embed=True)): """ **inputs**: **attention_mask**: diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index a0489a4e06..61106e73e9 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -28,8 +28,9 @@ from . import __version__ logger = logging.getLogger(__name__) # pylint: disable=invalid-name try: - if os.environ.get("USE_TORCH", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \ - os.environ.get("USE_TF", 'AUTO').upper() not in ("1", "ON", "YES"): + USE_TF = os.environ.get("USE_TF", "AUTO").upper() + USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() + if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"): import torch _torch_available = True # pylint: disable=invalid-name @@ -41,8 +42,10 @@ except ImportError: _torch_available = False # pylint: disable=invalid-name try: - if os.environ.get("USE_TF", 'AUTO').upper() in ("1", "ON", "YES", "AUTO") and \ - os.environ.get("USE_TORCH", 'AUTO').upper() not in ("1", "ON", "YES"): + USE_TF = os.environ.get("USE_TF", "AUTO").upper() + USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() + + if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"): import tensorflow as tf assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2