update serving command

This commit is contained in:
thomwolf
2019-12-20 13:47:35 +01:00
parent 15dda5ea32
commit 73fcebf7ec
3 changed files with 27 additions and 16 deletions

View File

@@ -38,9 +38,9 @@ from setuptools import find_packages, setup
extras = {
'serving': ['uvicorn', 'fastapi'],
'serving-tf': ['uvicorn', 'fastapi', 'tensorflow'],
'serving-torch': ['uvicorn', 'fastapi', 'torch']
'serving': ['pydantic', 'uvicorn', 'fastapi'],
'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'],
'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch']
}
extras['all'] = [package for package in extras.values()]

View File

@@ -3,9 +3,9 @@ from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand
from transformers.commands.run import RunCommand
from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands
from transformers.commands.convert import ConvertCommand
from transformers.commands.serving import ServeCommand
if __name__ == '__main__':
parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')

View File

@@ -1,16 +1,23 @@
from argparse import ArgumentParser, Namespace
from typing import List, Optional, Union, Any
from fastapi import FastAPI, HTTPException, Body
from logging import getLogger
import logging
from pydantic import BaseModel
try:
from uvicorn import run
from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel
_serve_dependancies_installed = True
except (ImportError, AttributeError):
BaseModel = object
Body = lambda *x, **y: None
_serve_dependancies_installed = False
from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline
logger = logging.getLogger('transformers-cli/serving')
def serve_command_factory(args: Namespace):
"""
@@ -70,13 +77,17 @@ class ServeCommand(BaseTransformersCLICommand):
serve_parser.set_defaults(func=serve_command_factory)
def __init__(self, pipeline: Pipeline, host: str, port: int):
self._logger = getLogger('transformers-cli/serving')
self._pipeline = pipeline
self._logger.info('Serving model over {}:{}'.format(host, port))
self._host = host
self._port = port
if not _serve_dependancies_installed:
raise ImportError("Using serve command requires FastAPI and unicorn. "
"Please install transformers with [serving]: pip install transformers[serving]."
"Or install FastAPI and unicorn separatly.")
else:
logger.info('Serving model over {}:{}'.format(host, port))
self._app = FastAPI()
# Register routes