Pipelines: miscellanea of QoL improvements and small features... (#4632)
* [hf_api] Attach all unknown attributes for future-proof compatibility * [Pipeline] NerPipeline is really a TokenClassificationPipeline * modelcard.py: I don't think we need to force the download * Remove config, tokenizer from SUPPORTED_TASKS as we're moving to one model = one weight + one tokenizer * FillMaskPipeline: also output token in string form * TextClassificationPipeline: option to return all scores, not just the argmax * Update docs/source/main_classes/pipelines.rst
This commit is contained in:
@@ -8,7 +8,7 @@ Recognition, Masked Language Modeling, Sentiment Analysis, Feature Extraction an
|
|||||||
There are two categories of pipeline abstractions to be aware about:
|
There are two categories of pipeline abstractions to be aware about:
|
||||||
|
|
||||||
- The :class:`~transformers.pipeline` which is the most powerful object encapsulating all other pipelines
|
- The :class:`~transformers.pipeline` which is the most powerful object encapsulating all other pipelines
|
||||||
- The other task-specific pipelines, such as :class:`~transformers.NerPipeline`
|
- The other task-specific pipelines, such as :class:`~transformers.TokenClassificationPipeline`
|
||||||
or :class:`~transformers.QuestionAnsweringPipeline`
|
or :class:`~transformers.QuestionAnsweringPipeline`
|
||||||
|
|
||||||
The pipeline abstraction
|
The pipeline abstraction
|
||||||
@@ -30,15 +30,15 @@ Parent class: Pipeline
|
|||||||
.. autoclass:: transformers.Pipeline
|
.. autoclass:: transformers.Pipeline
|
||||||
:members: predict, transform, save_pretrained
|
:members: predict, transform, save_pretrained
|
||||||
|
|
||||||
NerPipeline
|
|
||||||
==========================================
|
|
||||||
|
|
||||||
.. autoclass:: transformers.NerPipeline
|
|
||||||
|
|
||||||
TokenClassificationPipeline
|
TokenClassificationPipeline
|
||||||
==========================================
|
==========================================
|
||||||
|
|
||||||
This class is an alias of the :class:`~transformers.NerPipeline` defined above. Please refer to that pipeline for
|
.. autoclass:: transformers.TokenClassificationPipeline
|
||||||
|
|
||||||
|
NerPipeline
|
||||||
|
==========================================
|
||||||
|
|
||||||
|
This class is an alias of the :class:`~transformers.TokenClassificationPipeline` defined above. Please refer to that pipeline for
|
||||||
documentation and usage examples.
|
documentation and usage examples.
|
||||||
|
|
||||||
FillMaskPipeline
|
FillMaskPipeline
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ class S3Object:
|
|||||||
self.lastModified = lastModified
|
self.lastModified = lastModified
|
||||||
self.size = size
|
self.size = size
|
||||||
self.rfilename = rfilename
|
self.rfilename = rfilename
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo:
|
class ModelInfo:
|
||||||
@@ -78,7 +80,7 @@ class ModelInfo:
|
|||||||
author: Optional[str] = None,
|
author: Optional[str] = None,
|
||||||
downloads: Optional[int] = None,
|
downloads: Optional[int] = None,
|
||||||
tags: List[str] = [],
|
tags: List[str] = [],
|
||||||
siblings: List[Dict] = [], # list of files that constitute the model
|
siblings: Optional[List[Dict]] = None, # list of files that constitute the model
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
self.modelId = modelId
|
self.modelId = modelId
|
||||||
@@ -86,7 +88,9 @@ class ModelInfo:
|
|||||||
self.author = author
|
self.author = author
|
||||||
self.downloads = downloads
|
self.downloads = downloads
|
||||||
self.tags = tags
|
self.tags = tags
|
||||||
self.siblings = [S3Object(**x) for x in siblings]
|
self.siblings = [S3Object(**x) for x in siblings] if siblings is not None else None
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
|
||||||
class HfApi:
|
class HfApi:
|
||||||
|
|||||||
@@ -149,9 +149,7 @@ class ModelCard:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Load from URL or cache if already cached
|
# Load from URL or cache if already cached
|
||||||
resolved_model_card_file = cached_path(
|
resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, proxies=proxies)
|
||||||
model_card_file, cache_dir=cache_dir, force_download=True, proxies=proxies, resume_download=False
|
|
||||||
)
|
|
||||||
if resolved_model_card_file is None:
|
if resolved_model_card_file is None:
|
||||||
raise EnvironmentError
|
raise EnvironmentError
|
||||||
if resolved_model_card_file == model_card_file:
|
if resolved_model_card_file == model_card_file:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence,
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
from .configuration_auto import AutoConfig
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .data import SquadExample, squad_convert_examples_to_features
|
from .data import SquadExample, squad_convert_examples_to_features
|
||||||
from .file_utils import is_tf_available, is_torch_available
|
from .file_utils import is_tf_available, is_torch_available
|
||||||
@@ -717,10 +717,23 @@ class TextClassificationPipeline(Pipeline):
|
|||||||
on the associated CUDA device id.
|
on the associated CUDA device id.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, return_all_scores: bool = False, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.return_all_scores = return_all_scores
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
outputs = super().__call__(*args, **kwargs)
|
outputs = super().__call__(*args, **kwargs)
|
||||||
scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
|
scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
|
||||||
return [{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores]
|
if self.return_all_scores:
|
||||||
|
return [
|
||||||
|
[{"label": self.model.config.id2label[i], "score": score} for i, score in enumerate(item)]
|
||||||
|
for item in scores
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
{"label": self.model.config.id2label[item.argmax()], "score": item.max().item()} for item in scores
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class FillMaskPipeline(Pipeline):
|
class FillMaskPipeline(Pipeline):
|
||||||
@@ -813,7 +826,14 @@ class FillMaskPipeline(Pipeline):
|
|||||||
tokens[masked_index] = p
|
tokens[masked_index] = p
|
||||||
# Filter padding out:
|
# Filter padding out:
|
||||||
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
|
tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
|
||||||
result.append({"sequence": self.tokenizer.decode(tokens), "score": v, "token": p})
|
result.append(
|
||||||
|
{
|
||||||
|
"sequence": self.tokenizer.decode(tokens),
|
||||||
|
"score": v,
|
||||||
|
"token": p,
|
||||||
|
"token_str": self.tokenizer.convert_ids_to_tokens(p),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Append
|
# Append
|
||||||
results += [result]
|
results += [result]
|
||||||
@@ -823,7 +843,7 @@ class FillMaskPipeline(Pipeline):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class NerPipeline(Pipeline):
|
class TokenClassificationPipeline(Pipeline):
|
||||||
"""
|
"""
|
||||||
Named Entity Recognition pipeline using ModelForTokenClassification head. See the
|
Named Entity Recognition pipeline using ModelForTokenClassification head. See the
|
||||||
`named entity recognition usage <../usage.html#named-entity-recognition>`__ examples for more information.
|
`named entity recognition usage <../usage.html#named-entity-recognition>`__ examples for more information.
|
||||||
@@ -987,7 +1007,7 @@ class NerPipeline(Pipeline):
|
|||||||
return entity_group
|
return entity_group
|
||||||
|
|
||||||
|
|
||||||
TokenClassificationPipeline = NerPipeline
|
NerPipeline = TokenClassificationPipeline
|
||||||
|
|
||||||
|
|
||||||
class QuestionAnsweringArgumentHandler(ArgumentHandler):
|
class QuestionAnsweringArgumentHandler(ArgumentHandler):
|
||||||
@@ -1577,11 +1597,7 @@ SUPPORTED_TASKS = {
|
|||||||
"impl": FeatureExtractionPipeline,
|
"impl": FeatureExtractionPipeline,
|
||||||
"tf": TFAutoModel if is_tf_available() else None,
|
"tf": TFAutoModel if is_tf_available() else None,
|
||||||
"pt": AutoModel if is_torch_available() else None,
|
"pt": AutoModel if is_torch_available() else None,
|
||||||
"default": {
|
"default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
|
||||||
"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"},
|
|
||||||
"config": None,
|
|
||||||
"tokenizer": "distilbert-base-cased",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"sentiment-analysis": {
|
"sentiment-analysis": {
|
||||||
"impl": TextClassificationPipeline,
|
"impl": TextClassificationPipeline,
|
||||||
@@ -1592,12 +1608,10 @@ SUPPORTED_TASKS = {
|
|||||||
"pt": "distilbert-base-uncased-finetuned-sst-2-english",
|
"pt": "distilbert-base-uncased-finetuned-sst-2-english",
|
||||||
"tf": "distilbert-base-uncased-finetuned-sst-2-english",
|
"tf": "distilbert-base-uncased-finetuned-sst-2-english",
|
||||||
},
|
},
|
||||||
"config": "distilbert-base-uncased-finetuned-sst-2-english",
|
|
||||||
"tokenizer": "distilbert-base-uncased",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"ner": {
|
"ner": {
|
||||||
"impl": NerPipeline,
|
"impl": TokenClassificationPipeline,
|
||||||
"tf": TFAutoModelForTokenClassification if is_tf_available() else None,
|
"tf": TFAutoModelForTokenClassification if is_tf_available() else None,
|
||||||
"pt": AutoModelForTokenClassification if is_torch_available() else None,
|
"pt": AutoModelForTokenClassification if is_torch_available() else None,
|
||||||
"default": {
|
"default": {
|
||||||
@@ -1605,8 +1619,6 @@ SUPPORTED_TASKS = {
|
|||||||
"pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
"pt": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||||
"tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
"tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
||||||
},
|
},
|
||||||
"config": "dbmdz/bert-large-cased-finetuned-conll03-english",
|
|
||||||
"tokenizer": "bert-large-cased",
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"question-answering": {
|
"question-answering": {
|
||||||
@@ -1615,61 +1627,43 @@ SUPPORTED_TASKS = {
|
|||||||
"pt": AutoModelForQuestionAnswering if is_torch_available() else None,
|
"pt": AutoModelForQuestionAnswering if is_torch_available() else None,
|
||||||
"default": {
|
"default": {
|
||||||
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
|
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
|
||||||
"config": None,
|
|
||||||
"tokenizer": ("distilbert-base-cased", {"use_fast": False}),
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"fill-mask": {
|
"fill-mask": {
|
||||||
"impl": FillMaskPipeline,
|
"impl": FillMaskPipeline,
|
||||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||||
"default": {
|
"default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
|
||||||
"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"},
|
|
||||||
"config": None,
|
|
||||||
"tokenizer": ("distilroberta-base", {"use_fast": False}),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"summarization": {
|
"summarization": {
|
||||||
"impl": SummarizationPipeline,
|
"impl": SummarizationPipeline,
|
||||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||||
"default": {"model": {"pt": "facebook/bart-large-cnn", "tf": "t5-small"}, "config": None, "tokenizer": None},
|
"default": {"model": {"pt": "facebook/bart-large-cnn", "tf": "t5-small"}},
|
||||||
},
|
},
|
||||||
"translation_en_to_fr": {
|
"translation_en_to_fr": {
|
||||||
"impl": TranslationPipeline,
|
"impl": TranslationPipeline,
|
||||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||||
"default": {
|
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||||
"model": {"pt": "t5-base", "tf": "t5-base"},
|
|
||||||
"config": None,
|
|
||||||
"tokenizer": ("t5-base", {"use_fast": False}),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"translation_en_to_de": {
|
"translation_en_to_de": {
|
||||||
"impl": TranslationPipeline,
|
"impl": TranslationPipeline,
|
||||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||||
"default": {
|
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||||
"model": {"pt": "t5-base", "tf": "t5-base"},
|
|
||||||
"config": None,
|
|
||||||
"tokenizer": ("t5-base", {"use_fast": False}),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"translation_en_to_ro": {
|
"translation_en_to_ro": {
|
||||||
"impl": TranslationPipeline,
|
"impl": TranslationPipeline,
|
||||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||||
"default": {
|
"default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
|
||||||
"model": {"pt": "t5-base", "tf": "t5-base"},
|
|
||||||
"config": None,
|
|
||||||
"tokenizer": ("t5-base", {"use_fast": False}),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
"text-generation": {
|
"text-generation": {
|
||||||
"impl": TextGenerationPipeline,
|
"impl": TextGenerationPipeline,
|
||||||
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
|
||||||
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
"pt": AutoModelWithLMHead if is_torch_available() else None,
|
||||||
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}, "config": None, "tokenizer": "gpt2"},
|
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1698,11 +1692,12 @@ def pipeline(
|
|||||||
|
|
||||||
- "feature-extraction": will return a :class:`~transformers.FeatureExtractionPipeline`
|
- "feature-extraction": will return a :class:`~transformers.FeatureExtractionPipeline`
|
||||||
- "sentiment-analysis": will return a :class:`~transformers.TextClassificationPipeline`
|
- "sentiment-analysis": will return a :class:`~transformers.TextClassificationPipeline`
|
||||||
- "ner": will return a :class:`~transformers.NerPipeline`
|
- "ner": will return a :class:`~transformers.TokenClassificationPipeline`
|
||||||
- "question-answering": will return a :class:`~transformers.QuestionAnsweringPipeline`
|
- "question-answering": will return a :class:`~transformers.QuestionAnsweringPipeline`
|
||||||
- "fill-mask": will return a :class:`~transformers.FillMaskPipeline`
|
- "fill-mask": will return a :class:`~transformers.FillMaskPipeline`
|
||||||
- "summarization": will return a :class:`~transformers.SummarizationPipeline`
|
- "summarization": will return a :class:`~transformers.SummarizationPipeline`
|
||||||
- "translation_xx_to_yy": will return a :class:`~transformers.TranslationPipeline`
|
- "translation_xx_to_yy": will return a :class:`~transformers.TranslationPipeline`
|
||||||
|
- "text-generation": will return a :class:`~transformers.TextGenerationPipeline`
|
||||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||||
The model that will be used by the pipeline to make predictions. This can be :obj:`None`,
|
The model that will be used by the pipeline to make predictions. This can be :obj:`None`,
|
||||||
a model identifier or an actual pre-trained model inheriting from
|
a model identifier or an actual pre-trained model inheriting from
|
||||||
@@ -1759,14 +1754,13 @@ def pipeline(
|
|||||||
|
|
||||||
# Use default model/config/tokenizer for the task if no model is provided
|
# Use default model/config/tokenizer for the task if no model is provided
|
||||||
if model is None:
|
if model is None:
|
||||||
models, config, tokenizer = [targeted_task["default"][k] for k in ["model", "config", "tokenizer"]]
|
model = targeted_task["default"]["model"][framework]
|
||||||
model = models[framework]
|
|
||||||
|
|
||||||
# Try to infer tokenizer from model or config name (if provided as str)
|
# Try to infer tokenizer from model or config name (if provided as str)
|
||||||
if tokenizer is None:
|
if tokenizer is None:
|
||||||
if isinstance(model, str) and model in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
if isinstance(model, str):
|
||||||
tokenizer = model
|
tokenizer = model
|
||||||
elif isinstance(config, str) and config in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
|
elif isinstance(config, str):
|
||||||
tokenizer = config
|
tokenizer = config
|
||||||
else:
|
else:
|
||||||
# Impossible to guest what is the right tokenizer here
|
# Impossible to guest what is the right tokenizer here
|
||||||
|
|||||||
Reference in New Issue
Block a user