update serving command
This commit is contained in:
6
setup.py
6
setup.py
@@ -38,9 +38,9 @@ from setuptools import find_packages, setup
|
|||||||
|
|
||||||
|
|
||||||
extras = {
|
extras = {
|
||||||
'serving': ['uvicorn', 'fastapi'],
|
'serving': ['pydantic', 'uvicorn', 'fastapi'],
|
||||||
'serving-tf': ['uvicorn', 'fastapi', 'tensorflow'],
|
'serving-tf': ['pydantic', 'uvicorn', 'fastapi', 'tensorflow'],
|
||||||
'serving-torch': ['uvicorn', 'fastapi', 'torch']
|
'serving-torch': ['pydantic', 'uvicorn', 'fastapi', 'torch']
|
||||||
}
|
}
|
||||||
extras['all'] = [package for package in extras.values()]
|
extras['all'] = [package for package in extras.values()]
|
||||||
|
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from argparse import ArgumentParser
|
|||||||
|
|
||||||
from transformers.commands.download import DownloadCommand
|
from transformers.commands.download import DownloadCommand
|
||||||
from transformers.commands.run import RunCommand
|
from transformers.commands.run import RunCommand
|
||||||
from transformers.commands.serving import ServeCommand
|
|
||||||
from transformers.commands.user import UserCommands
|
from transformers.commands.user import UserCommands
|
||||||
from transformers.commands.convert import ConvertCommand
|
from transformers.commands.convert import ConvertCommand
|
||||||
|
from transformers.commands.serving import ServeCommand
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')
|
parser = ArgumentParser('Transformers CLI tool', usage='transformers-cli <command> [<args>]')
|
||||||
|
|||||||
@@ -1,16 +1,23 @@
|
|||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from typing import List, Optional, Union, Any
|
from typing import List, Optional, Union, Any
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Body
|
import logging
|
||||||
from logging import getLogger
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
try:
|
||||||
from uvicorn import run
|
from uvicorn import run
|
||||||
|
from fastapi import FastAPI, HTTPException, Body
|
||||||
|
from pydantic import BaseModel
|
||||||
|
_serve_dependancies_installed = True
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
BaseModel = object
|
||||||
|
Body = lambda *x, **y: None
|
||||||
|
_serve_dependancies_installed = False
|
||||||
|
|
||||||
from transformers import Pipeline
|
from transformers import Pipeline
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||||
|
|
||||||
|
logger = logging.getLogger('transformers-cli/serving')
|
||||||
|
|
||||||
def serve_command_factory(args: Namespace):
|
def serve_command_factory(args: Namespace):
|
||||||
"""
|
"""
|
||||||
@@ -70,13 +77,17 @@ class ServeCommand(BaseTransformersCLICommand):
|
|||||||
serve_parser.set_defaults(func=serve_command_factory)
|
serve_parser.set_defaults(func=serve_command_factory)
|
||||||
|
|
||||||
def __init__(self, pipeline: Pipeline, host: str, port: int):
|
def __init__(self, pipeline: Pipeline, host: str, port: int):
|
||||||
self._logger = getLogger('transformers-cli/serving')
|
|
||||||
|
|
||||||
self._pipeline = pipeline
|
self._pipeline = pipeline
|
||||||
|
|
||||||
self._logger.info('Serving model over {}:{}'.format(host, port))
|
|
||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
|
if not _serve_dependancies_installed:
|
||||||
|
raise ImportError("Using serve command requires FastAPI and unicorn. "
|
||||||
|
"Please install transformers with [serving]: pip install transformers[serving]."
|
||||||
|
"Or install FastAPI and unicorn separatly.")
|
||||||
|
else:
|
||||||
|
logger.info('Serving model over {}:{}'.format(host, port))
|
||||||
self._app = FastAPI()
|
self._app = FastAPI()
|
||||||
|
|
||||||
# Register routes
|
# Register routes
|
||||||
|
|||||||
Reference in New Issue
Block a user