update serving command

This commit is contained in:
thomwolf
2019-12-20 13:47:35 +01:00
parent 15dda5ea32
commit 73fcebf7ec
3 changed files with 27 additions and 16 deletions

View File

@@ -38,9 +38,9 @@ from setuptools import find_packages, setup
extras = { extras = {
'serving': ['uvicorn', 'fastapi'], 'serving': ['pydantic', 'uvicorn', 'fastapi'],
'serving-tf': ['uvicorn', 'fastapi', 'tensorflow'], 'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'],
'serving-torch': ['uvicorn', 'fastapi', 'torch'] 'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch']
} }
extras['all'] = [package for package in extras.values()] extras['all'] = [package for package in extras.values()]

View File

@@ -3,9 +3,9 @@ from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand from transformers.commands.download import DownloadCommand
from transformers.commands.run import RunCommand from transformers.commands.run import RunCommand
from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands from transformers.commands.user import UserCommands
from transformers.commands.convert import ConvertCommand from transformers.commands.convert import ConvertCommand
from transformers.commands.serving import ServeCommand
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]') parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')

View File

@@ -1,16 +1,23 @@
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Any from typing import List, Optional, Union, Any
from fastapi import FastAPI, HTTPException, Body import logging
from logging import getLogger
from pydantic import BaseModel try:
from uvicorn import run from uvicorn import run
from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel
_serve_dependancies_installed = True
except (ImportError, AttributeError):
BaseModel = object
Body = lambda *x, **y: None
_serve_dependancies_installed = False
from transformers import Pipeline from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline from transformers.pipelines import SUPPORTED_TASKS, pipeline
logger = logging.getLogger('transformers-cli/serving')
def serve_command_factory(args: Namespace): def serve_command_factory(args: Namespace):
""" """
@@ -70,20 +77,24 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser.set_defaults(func=serve_command_factory) serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int): def __init__(self, pipeline: Pipeline, host: str, port: int):
self._logger = getLogger('transformers-cli/serving')
self._pipeline = pipeline self._pipeline = pipeline
self._logger.info('Serving model over {}:{}'.format(host, port))
self._host = host self._host = host
self._port = port self._port = port
self._app = FastAPI() if not _serve_dependancies_installed:
raise ImportError("Using serve command requires FastAPI and unicorn. "
"Please install transformers with [serving]: pip install transformers[serving]."
"Or install FastAPI and unicorn separatly.")
else:
logger.info('Serving model over {}:{}'.format(host, port))
self._app = FastAPI()
# Register routes # Register routes
self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET']) self._app.add_api_route('/', self.model_info, response_model=ServeModelInfoResult, methods=['GET'])
self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST']) self._app.add_api_route('/tokenize', self.tokenize, response_model=ServeTokenizeResult, methods=['POST'])
self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST']) self._app.add_api_route('/detokenize', self.detokenize, response_model=ServeDeTokenizeResult, methods=['POST'])
self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST']) self._app.add_api_route('/forward', self.forward, response_model=ServeForwardResult, methods=['POST'])
def run(self): def run(self):
run(self._app, host=self._host, port=self._port) run(self._app, host=self._host, port=self._port)