Binding pipelines to the cli.
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
from transformers.commands.download import DownloadCommand
|
from transformers.commands.download import DownloadCommand
|
||||||
|
from transformers.commands.run import RunCommand
|
||||||
from transformers.commands.serving import ServeCommand
|
from transformers.commands.serving import ServeCommand
|
||||||
from transformers.commands.user import UserCommands
|
from transformers.commands.user import UserCommands
|
||||||
from transformers.commands.train import TrainCommand
|
from transformers.commands.train import TrainCommand
|
||||||
@@ -14,9 +15,10 @@ if __name__ == '__main__':
|
|||||||
# Register commands
|
# Register commands
|
||||||
ConvertCommand.register_subcommand(commands_parser)
|
ConvertCommand.register_subcommand(commands_parser)
|
||||||
DownloadCommand.register_subcommand(commands_parser)
|
DownloadCommand.register_subcommand(commands_parser)
|
||||||
|
RunCommand.register_subcommand(commands_parser)
|
||||||
ServeCommand.register_subcommand(commands_parser)
|
ServeCommand.register_subcommand(commands_parser)
|
||||||
UserCommands.register_subcommand(commands_parser)
|
|
||||||
TrainCommand.register_subcommand(commands_parser)
|
TrainCommand.register_subcommand(commands_parser)
|
||||||
|
UserCommands.register_subcommand(commands_parser)
|
||||||
|
|
||||||
# Let's go
|
# Let's go
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
56
transformers/commands/run.py
Normal file
56
transformers/commands/run.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
|
from transformers.pipelines import pipeline, Pipeline, PipelineDataFormat, SUPPORTED_TASKS
|
||||||
|
|
||||||
|
|
||||||
|
def try_infer_format_from_ext(path: str):
    """
    Guess the data format of ``path`` from the way its name ends.

    The entries of ``PipelineDataFormat.SUPPORTED_FORMATS`` are tried in order and the
    first one ``path`` ends with is returned.

    Args:
        path: Path (or file name) of the input file.

    Returns:
        The matching entry of ``PipelineDataFormat.SUPPORTED_FORMATS``.

    Raises:
        Exception: if ``path`` ends with none of the supported formats.
    """
    known_formats = PipelineDataFormat.SUPPORTED_FORMATS

    # First supported format that terminates the path, or None when nothing matches.
    matched = next((candidate for candidate in known_formats if path.endswith(candidate)), None)
    if matched is not None:
        return matched

    message = (
        'Unable to determine file format from file extension {}. '
        'Please provide the format through --format {}'
    )
    raise Exception(message.format(path, known_formats))
|
||||||
|
|
||||||
|
|
||||||
|
def run_command_factory(args):
    """
    Build the ``RunCommand`` described by the parsed command line arguments.

    Args:
        args: Parsed arguments; ``task``, ``model``, ``tokenizer``, ``format``, ``input``,
            ``output`` and ``column`` are read from it.

    Returns:
        A ``RunCommand`` wrapping the instantiated pipeline and the data reader.
    """
    # The pipeline is created first, exactly as before, so any failure while building it
    # surfaces ahead of format/IO errors.
    nlp = pipeline(task=args.task, model=args.model, tokenizer=args.tokenizer)

    # 'infer' means the format has to be deduced from the input file name.
    if args.format == 'infer':
        data_format = try_infer_format_from_ext(args.input)
    else:
        data_format = args.format

    reader = PipelineDataFormat.from_str(data_format, args.output, args.input, args.column)
    return RunCommand(nlp, reader)
|
||||||
|
|
||||||
|
|
||||||
|
class RunCommand(BaseTransformersCLICommand):
    """
    CLI command that feeds every entry produced by a data reader through a pipeline and
    hands the collected results back to the reader for saving.
    """

    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        """
        Args:
            nlp: Pipeline called on each entry.
            reader: Data format object iterated for inputs and used to save the outputs.
        """
        self._nlp = nlp
        self._reader = reader

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register the ``run`` sub-command and its options.

        Args:
            parser: Object exposing ``add_parser`` on which the ``run`` sub-command is created.
        """
        run_parser = parser.add_parser('run', help="Run a pipeline through the CLI")
        run_parser.add_argument('--task', choices=SUPPORTED_TASKS.keys(), help='Task to run')
        run_parser.add_argument('--model', type=str, required=True, help='Name or path to the model to instantiate.')
        run_parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use. (default: same as the model name)')
        run_parser.add_argument('--column', type=str, required=True, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
        # 'infer' is the default value, so it must also be an accepted choice: argparse does not
        # validate defaults against `choices`, but it would reject an explicit `--format infer`
        # and omit it from the usage/help text.
        run_parser.add_argument('--format', type=str, default='infer', choices=['infer'] + list(PipelineDataFormat.SUPPORTED_FORMATS), help='Input format to read from')
        run_parser.add_argument('--input', type=str, required=True, help='Path to the file to use for inference')
        run_parser.add_argument('--output', type=str, required=True, help='Path to the file that will be used post to write results.')
        run_parser.add_argument('kwargs', nargs='*', help='Arguments to forward to the file format reader')
        # The CLI entry point is expected to call `args.func(args)` to build the command.
        run_parser.set_defaults(func=run_command_factory)

    def run(self):
        """Run the pipeline over every entry of the reader, then save all results at once."""
        nlp, output = self._nlp, []
        for entry in self._reader:
            if self._reader.is_multi_columns:
                # Multi-column entries are dicts whose keys are forwarded as keyword arguments.
                output.append(nlp(**entry))
            else:
                output.append(nlp(entry))

        # Saving data
        self._reader.save(output)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -14,6 +14,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from itertools import groupby
|
from itertools import groupby
|
||||||
@@ -25,11 +27,13 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PretrainedConfig, \
|
|||||||
SquadExample, squad_convert_examples_to_features, is_tf_available, is_torch_available, logger
|
SquadExample, squad_convert_examples_to_features, is_tf_available, is_torch_available, logger
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from transformers import TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering, TFAutoModelForTokenClassification
|
from transformers import TFAutoModel, TFAutoModelForSequenceClassification, \
|
||||||
|
TFAutoModelForQuestionAnswering, TFAutoModelForTokenClassification
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelForTokenClassification
|
from transformers import AutoModel, AutoModelForSequenceClassification, \
|
||||||
|
AutoModelForQuestionAnswering, AutoModelForTokenClassification
|
||||||
|
|
||||||
|
|
||||||
class Pipeline(ABC):
|
class Pipeline(ABC):
|
||||||
@@ -58,6 +62,84 @@ class Pipeline(ABC):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineDataFormat:
    """
    Base description of where pipeline inputs are read from and where results are written.

    Subclasses provide ``__iter__`` (yielding the inputs) and ``save`` (persisting the results).

    Attributes set by ``__init__``:
        output: Path the results will be written to; must not exist yet.
        path: Path of the input file; must exist.
        column: For a single column, a one-element list holding its name. For several
            comma-separated columns, a list of tuples where ``a=b`` becomes ``('a', 'b')``
            and a bare ``a`` becomes ``('a', 'a')``.
        is_multi_columns: True when more than one comma-separated column was given.
    """
    SUPPORTED_FORMATS = ['json', 'csv']

    def __init__(self, output: str, path: str, column: str):
        from os.path import abspath, exists

        self.output = output
        self.path = path

        names = column.split(',')
        self.is_multi_columns = len(names) > 1
        if self.is_multi_columns:
            # Each spec is either "key=column" or a bare name used for both roles.
            pairs = []
            for name in names:
                pairs.append(tuple(name.split('=')) if '=' in name else (name, name))
            names = pairs
        self.column = names

        # Refuse to overwrite an existing output file.
        if exists(abspath(self.output)):
            raise OSError('{} already exists on disk'.format(self.output))

        # The input file has to be there before anything is read.
        if not exists(abspath(self.path)):
            raise OSError('{} doesnt exist on disk'.format(self.path))

    @abstractmethod
    def __iter__(self):
        """Yield the pipeline inputs; implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def save(self, data: dict):
        """Persist the pipeline results; implemented by subclasses."""
        raise NotImplementedError()

    @staticmethod
    def from_str(name: str, output: str, path: str, column: str):
        """
        Instantiate the data format registered under ``name`` ('json' or 'csv').

        Raises:
            KeyError: if ``name`` is neither 'json' nor 'csv'.
        """
        if name == 'json':
            return JsonPipelineDataFormat(output, path, column)
        if name == 'csv':
            return CsvPipelineDataFormat(output, path, column)
        raise KeyError('Unknown reader {} (Available reader are json/csv)'.format(name))
|
||||||
|
|
||||||
|
|
||||||
|
class CsvPipelineDataFormat(PipelineDataFormat):
    """
    Data format reading pipeline inputs from a CSV file (with a header row) and writing the
    results to another CSV file.
    """

    def __init__(self, output: str, path: str, column: str):
        """
        Args:
            output: Path of the CSV file the results are written to.
            path: Path of the CSV file the inputs are read from.
            column: Column name, or comma-separated column specs, selecting the input(s).
        """
        super().__init__(output, path, column)

    def __iter__(self):
        """
        Yield one pipeline input per CSV row.

        In multi-column mode a dict is yielded, built from the ``(key, column)`` pairs in
        ``self.column``; otherwise the value of the single selected column is yielded.
        """
        # newline='' lets the csv module handle line endings itself (quoted embedded newlines).
        with open(self.path, 'r', newline='') as f:
            reader = csv.DictReader(f)
            for row in reader:
                if self.is_multi_columns:
                    yield {k: row[c] for k, c in self.column}
                else:
                    # self.column comes from column.split(',') and is therefore a list, even
                    # for a single column; indexing the row with the list itself raises
                    # TypeError (unhashable), so use its only element.
                    yield row[self.column[0]]

    def save(self, data: List[dict]):
        """
        Write ``data`` to ``self.output`` as CSV.

        The header is taken from the keys of the first item. With an empty ``data`` the
        output file is created but left empty.
        """
        # newline='' prevents the csv writer from emitting extra blank lines on some platforms.
        with open(self.output, 'w', newline='') as f:
            if len(data) > 0:
                writer = csv.DictWriter(f, list(data[0].keys()))
                writer.writeheader()
                writer.writerows(data)
|
||||||
|
|
||||||
|
|
||||||
|
class JsonPipelineDataFormat(PipelineDataFormat):
    """
    Data format reading pipeline inputs from a JSON file and dumping the results to another
    JSON file. The whole input file is loaded when the object is created.
    """

    def __init__(self, output: str, path: str, column: str):
        """
        Args:
            output: Path of the JSON file the results are written to.
            path: Path of the JSON file the inputs are read from.
            column: Key name, or comma-separated key specs, selecting the input(s).
        """
        super().__init__(output, path, column)

        with open(path, 'r') as f:
            self._entries = json.load(f)

    def __iter__(self):
        """
        Yield one pipeline input per loaded entry.

        In multi-column mode a dict is yielded, built from the ``(key, column)`` pairs in
        ``self.column``; otherwise the value stored under the single selected key is yielded.
        """
        # NOTE(review): assumes the JSON document is a sequence of dict-like entries -- confirm.
        for entry in self._entries:
            if self.is_multi_columns:
                yield {k: entry[c] for k, c in self.column}
            else:
                # self.column comes from column.split(',') and is therefore a list, even for
                # a single key; indexing a dict with the list itself raises TypeError
                # (unhashable), so use its only element.
                yield entry[self.column[0]]

    def save(self, data: dict):
        """Serialize ``data`` as JSON into ``self.output``."""
        with open(self.output, 'w') as f:
            json.dump(data, f)
|
||||||
|
|
||||||
|
|
||||||
class FeatureExtractionPipeline(Pipeline):
|
class FeatureExtractionPipeline(Pipeline):
|
||||||
|
|
||||||
def __call__(self, *texts, **kwargs):
|
def __call__(self, *texts, **kwargs):
|
||||||
@@ -127,7 +209,7 @@ class NerPipeline(Pipeline):
|
|||||||
label_idx = score.argmax()
|
label_idx = score.argmax()
|
||||||
|
|
||||||
answer += [{
|
answer += [{
|
||||||
'word': words[idx - 1], 'score': score[label_idx], 'entity': self.model.config.id2label[label_idx]
|
'word': words[idx - 1], 'score': score[label_idx].item(), 'entity': self.model.config.id2label[label_idx]
|
||||||
}]
|
}]
|
||||||
|
|
||||||
# Update token start
|
# Update token start
|
||||||
@@ -270,16 +352,18 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
char_to_word = np.array(example.char_to_word_offset)
|
char_to_word = np.array(example.char_to_word_offset)
|
||||||
|
|
||||||
# Convert the answer (tokens) back to the original text
|
# Convert the answer (tokens) back to the original text
|
||||||
answers += [[
|
answers += [
|
||||||
{
|
{
|
||||||
'score': score,
|
'score': score.item(),
|
||||||
'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0],
|
'start': np.where(char_to_word == feature.token_to_orig_map[s])[0][0].item(),
|
||||||
'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1],
|
'end': np.where(char_to_word == feature.token_to_orig_map[e])[0][-1].item(),
|
||||||
'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]: feature.token_to_orig_map[e] + 1])
|
'answer': ' '.join(example.doc_tokens[feature.token_to_orig_map[s]: feature.token_to_orig_map[e] + 1])
|
||||||
}
|
}
|
||||||
for s, e, score in zip(starts, ends, scores)
|
for s, e, score in zip(starts, ends, scores)
|
||||||
]]
|
]
|
||||||
|
|
||||||
|
if len(answers) == 1:
|
||||||
|
return answers[0]
|
||||||
return answers
|
return answers
|
||||||
|
|
||||||
def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
|
def decode(self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int) -> Tuple:
|
||||||
@@ -363,7 +447,7 @@ def pipeline(task: str, model, config: Optional[PretrainedConfig] = None, tokeni
|
|||||||
Utility factory method to build pipeline.
|
Utility factory method to build pipeline.
|
||||||
"""
|
"""
|
||||||
# Try to infer tokenizer from model name (if provided as str)
|
# Try to infer tokenizer from model name (if provided as str)
|
||||||
if not isinstance(tokenizer, PreTrainedTokenizer):
|
if tokenizer is None:
|
||||||
if not isinstance(model, str):
|
if not isinstance(model, str):
|
||||||
# Impossible to guess what is the right tokenizer here
|
# Impossible to guess what is the right tokenizer here
|
||||||
raise Exception('Tokenizer cannot be None if provided model is a PreTrainedModel instance')
|
raise Exception('Tokenizer cannot be None if provided model is a PreTrainedModel instance')
|
||||||
|
|||||||
Reference in New Issue
Block a user