Add Summarization to Pipelines (#3128)
* passing * Undo stupid chg * docs * undo rename * delete-cruft * only import if you have torch * Dont rely on dict ordering * Fix dict ordering upstream * docstring link * docstring link * remove trailing comma for 3.5 compat * new name * delegate kwarging * Update kwargs
This commit is contained in:
@@ -113,6 +113,7 @@ from .pipelines import (
|
||||
Pipeline,
|
||||
PipelineDataFormat,
|
||||
QuestionAnsweringPipeline,
|
||||
SummarizationPipeline,
|
||||
TextClassificationPipeline,
|
||||
TokenClassificationPipeline,
|
||||
pipeline,
|
||||
|
||||
@@ -60,6 +60,7 @@ if is_torch_available():
|
||||
AutoModelForTokenClassification,
|
||||
AutoModelWithLMHead,
|
||||
)
|
||||
from .modeling_bart import BartForConditionalGeneration
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1104,6 +1105,107 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||
return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}
|
||||
|
||||
|
||||
class SummarizationPipeline(Pipeline):
|
||||
"""
|
||||
Summarize news articles and other documents
|
||||
|
||||
Usage::
|
||||
|
||||
summarizer = pipeline("summarization")
|
||||
summarizer("Sam Shleifer writes the best docstring examples in the whole world.")
|
||||
|
||||
Supported Models:
|
||||
The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
|
||||
currently only ``BartForConditionalGeneration.from_pretrained('bart-large-cnn')``
|
||||
|
||||
Arguments:
|
||||
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||
checkpoint identifier or an actual pre-trained model inheriting from
|
||||
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||
TensorFlow.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||
:class:`~transformers.PreTrainedTokenizer`.
|
||||
|
||||
If :obj:`None`, the default of the pipeline will be loaded.
|
||||
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
|
||||
Model card attributed to the model for this pipeline.
|
||||
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||
installed.
|
||||
|
||||
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||
and both frameworks are installed, will default to PyTorch.
|
||||
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
|
||||
Reference to the object in charge of parsing supplied pipeline parameters.
|
||||
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
|
||||
on the associated CUDA device id.
|
||||
"""
|
||||
|
||||
task = "summarization"
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
*documents,
|
||||
return_tensors=False,
|
||||
return_text=True,
|
||||
max_length=142,
|
||||
min_length=21,
|
||||
clean_up_tokenization_spaces=False,
|
||||
**generate_kwargs
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
*documents: (list of strings) articles to be summarized
|
||||
return_text: (bool, default=True) whether to add a decoded "summary_text" to each result
|
||||
return_tensors: (bool, default=False) whether to return the raw "summary_token_ids" to each result
|
||||
|
||||
max_length: (`optional`) int
|
||||
The max length of the sequence to be generated. Does not include tokens in input_ids.
|
||||
min_len: (`optional`) int
|
||||
no_repeat_ngram_size: (`optional`) int. ban ngrams of this length from being repeated in the generated text
|
||||
clean_up_tokenization_spaces: (`optional`) bool whether to include extra spaces in the output
|
||||
**generate_kwargs: extra kwargs passed to `self.model.generate`_
|
||||
|
||||
Returns:
|
||||
list of dicts with 'summary_text' and/or 'summary_token_ids' for each document_to_summarize
|
||||
|
||||
.. _`self.model.generate`:
|
||||
https://huggingface.co/transformers/model_doc/bart.html#transformers.BartForConditionalGeneration.generate
|
||||
|
||||
"""
|
||||
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||
if self.framework == "tf":
|
||||
raise NotImplementedError("Tensorflow not supported")
|
||||
with self.device_placement():
|
||||
inputs = self._parse_and_tokenize(*documents)
|
||||
inputs = self.ensure_tensor_on_device(**inputs)
|
||||
summaries = self.model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
max_length=max_length,
|
||||
min_length=min_length,
|
||||
do_sample=False,
|
||||
**generate_kwargs,
|
||||
)
|
||||
results = []
|
||||
for summary in summaries:
|
||||
record = {}
|
||||
if return_tensors:
|
||||
record["summary_token_ids"] = summary
|
||||
if return_text:
|
||||
record["summary_text"] = self.tokenizer.decode(
|
||||
summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces
|
||||
)
|
||||
results.append(record)
|
||||
return results
|
||||
|
||||
|
||||
# Register all the supported task here
|
||||
SUPPORTED_TASKS = {
|
||||
"feature-extraction": {
|
||||
@@ -1162,6 +1264,16 @@ SUPPORTED_TASKS = {
|
||||
"tokenizer": ("distilroberta-base", {"use_fast": False}),
|
||||
},
|
||||
},
|
||||
"summarization": {
|
||||
"impl": SummarizationPipeline,
|
||||
"pt": BartForConditionalGeneration if is_torch_available() else None,
|
||||
"tf": None,
|
||||
"default": {
|
||||
"model": {"pt": "bart-large-cnn", "tf": None},
|
||||
"config": None,
|
||||
"tokenizer": ("bart-large-cnn", {"use_fast": False}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1253,7 +1365,7 @@ def pipeline(
|
||||
|
||||
# Use default model/config/tokenizer for the task if no model is provided
|
||||
if model is None:
|
||||
models, config, tokenizer = tuple(targeted_task["default"].values())
|
||||
models, config, tokenizer = [targeted_task["default"][k] for k in ["model", "config", "tokenizer"]]
|
||||
model = models[framework]
|
||||
|
||||
# Try to infer tokenizer from model or config name (if provided as str)
|
||||
|
||||
Reference in New Issue
Block a user