Pipelines: miscellanea of QoL improvements and small features... (#4632)
* [hf_api] Attach all unknown attributes for future-proof compatibility * [Pipeline] NerPipeline is really a TokenClassificationPipeline * modelcard.py: I don't think we need to force the download * Remove config, tokenizer from SUPPORTED_TASKS as we're moving to one model = one weight + one tokenizer * FillMaskPipeline: also output token in string form * TextClassificationPipeline: option to return all scores, not just the argmax * Update docs/source/main_classes/pipelines.rst
This commit is contained in:
@@ -8,7 +8,7 @@ Recognition, Masked Language Modeling, Sentiment Analysis, Feature Extraction an
|
||||
There are two categories of pipeline abstractions to be aware about:
|
||||
|
||||
- The :class:`~transformers.pipeline` which is the most powerful object encapsulating all other pipelines
|
||||
- The other task-specific pipelines, such as :class:`~transformers.NerPipeline`
|
||||
- The other task-specific pipelines, such as :class:`~transformers.TokenClassificationPipeline`
|
||||
or :class:`~transformers.QuestionAnsweringPipeline`
|
||||
|
||||
The pipeline abstraction
|
||||
@@ -30,15 +30,15 @@ Parent class: Pipeline
|
||||
.. autoclass:: transformers.Pipeline
|
||||
:members: predict, transform, save_pretrained
|
||||
|
||||
NerPipeline
|
||||
==========================================
|
||||
|
||||
.. autoclass:: transformers.NerPipeline
|
||||
|
||||
TokenClassificationPipeline
|
||||
==========================================
|
||||
|
||||
This class is an alias of the :class:`~transformers.NerPipeline` defined above. Please refer to that pipeline for
|
||||
.. autoclass:: transformers.TokenClassificationPipeline
|
||||
|
||||
NerPipeline
|
||||
==========================================
|
||||
|
||||
This class is an alias of the :class:`~transformers.TokenClassificationPipeline` defined above. Please refer to that pipeline for
|
||||
documentation and usage examples.
|
||||
|
||||
FillMaskPipeline
|
||||
|
||||
@@ -64,6 +64,8 @@ class S3Object:
|
||||
self.lastModified = lastModified
|
||||
self.size = size
|
||||
self.rfilename = rfilename
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class ModelInfo:
|
||||
@@ -78,7 +80,7 @@ class ModelInfo:
|
||||
author: Optional[str] = None,
|
||||
downloads: Optional[int] = None,
|
||||
tags: List[str] = [],
|
||||
siblings: List[Dict] = [], # list of files that constitute the model
|
||||
siblings: Optional[List[Dict]] = None, # list of files that constitute the model
|
||||
**kwargs
|
||||
):
|
||||
self.modelId = modelId
|
||||
@@ -86,7 +88,9 @@ class ModelInfo:
|
||||
self.author = author
|
||||
self.downloads = downloads
|
||||
self.tags = tags
|
||||
self.siblings = [S3Object(**x) for x in siblings]
|
||||
self.siblings = [S3Object(**x) for x in siblings] if siblings is not None else None
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class HfApi:
|
||||
|
||||
@@ -149,9 +149,7 @@ class ModelCard:
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_model_card_file = cached_path(
|
||||
model_card_file, cache_dir=cache_dir, force_download=True, proxies=proxies, resume_download=False
|
||||
)
|
||||
resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, proxies=proxies)
|
||||
if resolved_model_card_file is None:
|
||||
raise EnvironmentError
|
||||
if resolved_model_card_file == model_card_file:
|
||||
|
||||
@@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence,
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
||||
from .configuration_auto import AutoConfig
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .data import SquadExample, squad_convert_examples_to_features
|
||||
from .file_utils import is_tf_available, is_torch_available
|
||||
@@ -717,10 +717,23 @@ class TextClassificationPipeline(Pipeline):
|
||||
on the associated CUDA device id.
|
||||
"""
|
||||
|
||||
def __init__(self, return_all_scores: bool = False, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.return_all_scores = return_all_scores
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
outputs = super().__call__(*args, **kwargs)
|
||||
scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
|
||||
return [{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores]
|
||||
if self.return_all_scores:
|
||||
return [
|
||||
[{"label": self.model.config.id2label[i], "score": score} for i, score in enumerate(item)]
|
||||
for item in scores
|
||||
]
|
||||
else:
|
||||
return [
|
||||
{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores
|
||||
]
|
||||
|
||||
|
||||
class FillMaskPipeline(Pipeline):
|
||||
@@ -813,7 +826,14 @@ class FillMaskPipeline(Pipeline):
|
||||
tokens[masked_index] = p
|
||||
# Filter padding out:
|
||||
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
|
||||
result.append({"sequence": self.tokenizer.decode(tokens), "score": v, "token": p})
|
||||
result.append(
|
||||
{
|
||||
"sequence": self.tokenizer.decode(tokens),
|
||||
"score": v,
|
||||
"token": p,
|
||||
"token_str": self.tokenizer.convert_ids_to_tokens(p),
|
||||
}
|
||||
)
|
||||
|
||||
# Append
|
||||
results += [result]
|
||||
@@ -823,7 +843,7 @@ class FillMaskPipeline(Pipeline):
|
||||
return results
|
||||
|
||||
|
||||
class NerPipeline(Pipeline):
|
||||
class TokenClassificationPipeline(Pipeline):
|
||||
"""
|
||||
Named Entity Recognition pipeline using ModelForTokenClassification head. See the
|
||||
`named entity recognition usage <../usage.html#named-entity-recognition>`__ examples for more information.
|
||||
@@ -987,7 +1007,7 @@ class NerPipeline(Pipeline):
|
||||
return entity_group
|
||||
|
||||
|
||||
TokenClassificationPipeline = NerPipeline
|
||||
NerPipeline = TokenClassificationPipeline
|
||||
|
||||
|
||||
class QuestionAnsweringArgumentHandler(ArgumentHandler):
|
||||
@@ -1577,11 +1597,7 @@ SUPPORTED_TASKS = {
|
||||
"impl": FeatureExtractionPipeline,
|
||||
"tf": TFAutoModel if is_tf_available() else None,
|
||||
"pt": AutoModel if is_torch_available() else None,
|
||||
"default": {
|
||||
"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"},
|
||||
"config": None,
|
||||
"tokenizer": "distilbert-base-cased",
|
||||
},
|
||||
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
|
||||
},
|
||||
"sentiment-analysis": {
|
||||
"impl": TextClassificationPipeline,
|
||||
@@ -1592,12 +1608,10 @@ SUPPORTED_TASKS = {
|
||||
"pt": "distilbert-base-uncased-finetuned-sst-2-english",
|
||||
"tf": "distilbert-base-uncased-finetuned-sst-2-english",
|
||||
},
|
||||
"config": "distilbert-base-uncased-finetuned-sst-2-english",
|
||||
"tokenizer": "distilbert-base-uncased",
|
||||
},
|
||||
},
|
||||
"ner": {
|
||||
"impl": NerPipeline,
|
||||
"impl": TokenClassificationPipeline,
|
||||
"tf": TFAutoModelForTokenClassification if is_tf_available() else None,
|
||||
"pt": AutoModelForTokenClassification if is_torch_available() else None,
|
||||
"default": {
|
||||
@@ -1605,8 +1619,6 @@ SUPPORTED_TASKS = {
|
||||
"pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
"tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
},
|
||||
"config": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||
"tokenizer": "bert-large-cased",
|
||||
},
|
||||
},
|
||||
"question-answering": {
|
||||
@@ -1615,61 +1627,43 @@ SUPPORTED_TASKS = {
|
||||
"pt": AutoModelForQuestionAnswering if is_torch_available() else None,
|
||||
"default": {
|
||||
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
|
||||
"config": None,
|
||||
"tokenizer": ("distilbert-base-cased", {"use_fast": False}),
|
||||
},
|
||||
},
|
||||
"fill-mask": {
|
||||
"impl": FillMaskPipeline,
|
||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||
"default": {
|
||||
"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"},
|
||||
"config": None,
|
||||
"tokenizer": ("distilroberta-base", {"use_fast": False}),
|
||||
},
|
||||
"default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
|
||||
},
|
||||
"summarization": {
|
||||
"impl": SummarizationPipeline,
|
||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||
"default": {"model": {"pt": "facebook/bart-large-cnn", "tf": "t5-small"}, "config": None, "tokenizer": None},
|
||||
"default": {"model": {"pt": "facebook/bart-large-cnn", "tf": "t5-small"}},
|
||||
},
|
||||
"translation_en_to_fr": {
|
||||
"impl": TranslationPipeline,
|
||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||
"default": {
|
||||
"model": {"pt": "t5-base", "tf": "t5-base"},
|
||||
"config": None,
|
||||
"tokenizer": ("t5-base", {"use_fast": False}),
|
||||
},
|
||||
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
},
|
||||
"translation_en_to_de": {
|
||||
"impl": TranslationPipeline,
|
||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||
"default": {
|
||||
"model": {"pt": "t5-base", "tf": "t5-base"},
|
||||
"config": None,
|
||||
"tokenizer": ("t5-base", {"use_fast": False}),
|
||||
},
|
||||
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
},
|
||||
"translation_en_to_ro": {
|
||||
"impl": TranslationPipeline,
|
||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||
"default": {
|
||||
"model": {"pt": "t5-base", "tf": "t5-base"},
|
||||
"config": None,
|
||||
"tokenizer": ("t5-base", {"use_fast": False}),
|
||||
},
|
||||
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||
},
|
||||
"text-generation": {
|
||||
"impl": TextGenerationPipeline,
|
||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}, "config": None, "tokenizer": "gpt2"},
|
||||
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1698,11 +1692,12 @@ def pipeline(
|
||||
|
||||
- "feature-extraction": will return a :class:`~transformers.FeatureExtractionPipeline`
|
||||
- "sentiment-analysis": will return a :class:`~transformers.TextClassificationPipeline`
|
||||
- "ner": will return a :class:`~transformers.NerPipeline`
|
||||
- "ner": will return a :class:`~transformers.TokenClassificationPipeline`
|
||||
- "question-answering": will return a :class:`~transformers.QuestionAnsweringPipeline`
|
||||
- "fill-mask": will return a :class:`~transformers.FillMaskPipeline`
|
||||
- "summarization": will return a :class:`~transformers.SummarizationPipeline`
|
||||
- "translation_xx_to_yy": will return a :class:`~transformers.TranslationPipeline`
|
||||
- "text-generation": will return a :class:`~transformers.TextGenerationPipeline`
|
||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||
The model that will be used by the pipeline to make predictions. This can be :obj:`None`,
|
||||
a model identifier or an actual pre-trained model inheriting from
|
||||
@@ -1759,14 +1754,13 @@ def pipeline(
|
||||
|
||||
# Use default model/config/tokenizer for the task if no model is provided
|
||||
if model is None:
|
||||
models, config, tokenizer = [targeted_task["default"][k] for k in ["model", "config", "tokenizer"]]
|
||||
model = models[framework]
|
||||
model = targeted_task["default"]["model"][framework]
|
||||
|
||||
# Try to infer tokenizer from model or config name (if provided as str)
|
||||
if tokenizer is None:
|
||||
if isinstance(model, str) and model in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
||||
if isinstance(model, str):
|
||||
tokenizer = model
|
||||
elif isinstance(config, str) and config in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
||||
elif isinstance(config, str):
|
||||
tokenizer = config
|
||||
else:
|
||||
# Impossible to guest what is the right tokenizer here
|
||||
|
||||
Reference in New Issue
Block a user