Add Summarization to Pipelines (#3128)
* passing * Undo stupid chg * docs * undo rename * delete-cruft * only import if you have torch * Dont rely on dict ordering * Fix dict ordering upstream * docstring link * docstring link * remove trailing comma for 3.5 compat * new name * delegate kwarging * Update kwargs
This commit is contained in:
@@ -61,3 +61,8 @@ QuestionAnsweringPipeline
|
|||||||
|
|
||||||
.. autoclass:: transformers.QuestionAnsweringPipeline
|
.. autoclass:: transformers.QuestionAnsweringPipeline
|
||||||
|
|
||||||
|
|
||||||
|
SummarizationPipeline
|
||||||
|
==========================================
|
||||||
|
|
||||||
|
.. autoclass:: transformers.SummarizationPipeline
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ from .pipelines import (
|
|||||||
Pipeline,
|
Pipeline,
|
||||||
PipelineDataFormat,
|
PipelineDataFormat,
|
||||||
QuestionAnsweringPipeline,
|
QuestionAnsweringPipeline,
|
||||||
|
SummarizationPipeline,
|
||||||
TextClassificationPipeline,
|
TextClassificationPipeline,
|
||||||
TokenClassificationPipeline,
|
TokenClassificationPipeline,
|
||||||
pipeline,
|
pipeline,
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ if is_torch_available():
|
|||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
)
|
)
|
||||||
|
from .modeling_bart import BartForConditionalGeneration
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -1104,6 +1105,107 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}
|
return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)}
|
||||||
|
|
||||||
|
|
||||||
|
class SummarizationPipeline(Pipeline):
|
||||||
|
"""
|
||||||
|
Summarize news articles and other documents
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
summarizer = pipeline("summarization")
|
||||||
|
summarizer("Sam Shleifer writes the best docstring examples in the whole world.")
|
||||||
|
|
||||||
|
Supported Models:
|
||||||
|
The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
|
||||||
|
currently only ``BartForConditionalGeneration.from_pretrained('bart-large-cnn')``
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`):
|
||||||
|
The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string
|
||||||
|
checkpoint identifier or an actual pre-trained model inheriting from
|
||||||
|
:class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
|
||||||
|
TensorFlow.
|
||||||
|
|
||||||
|
If :obj:`None`, the default of the pipeline will be loaded.
|
||||||
|
tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`):
|
||||||
|
The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`,
|
||||||
|
a string checkpoint identifier or an actual pre-trained tokenizer inheriting from
|
||||||
|
:class:`~transformers.PreTrainedTokenizer`.
|
||||||
|
|
||||||
|
If :obj:`None`, the default of the pipeline will be loaded.
|
||||||
|
modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
|
||||||
|
Model card attributed to the model for this pipeline.
|
||||||
|
framework (:obj:`str`, `optional`, defaults to :obj:`None`):
|
||||||
|
The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
|
||||||
|
installed.
|
||||||
|
|
||||||
|
If no framework is specified, will default to the one currently installed. If no framework is specified
|
||||||
|
and both frameworks are installed, will default to PyTorch.
|
||||||
|
args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
|
||||||
|
Reference to the object in charge of parsing supplied pipeline parameters.
|
||||||
|
device (:obj:`int`, `optional`, defaults to :obj:`-1`):
|
||||||
|
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
|
||||||
|
on the associated CUDA device id.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task = "summarization"
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
*documents,
|
||||||
|
return_tensors=False,
|
||||||
|
return_text=True,
|
||||||
|
max_length=142,
|
||||||
|
min_length=21,
|
||||||
|
clean_up_tokenization_spaces=False,
|
||||||
|
**generate_kwargs
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
*documents: (list of strings) articles to be summarized
|
||||||
|
return_text: (bool, default=True) whether to add a decoded "summary_text" to each result
|
||||||
|
return_tensors: (bool, default=False) whether to return the raw "summary_token_ids" to each result
|
||||||
|
|
||||||
|
max_length: (`optional`) int
|
||||||
|
The max length of the sequence to be generated. Does not include tokens in input_ids.
|
||||||
|
min_len: (`optional`) int
|
||||||
|
no_repeat_ngram_size: (`optional`) int. ban ngrams of this length from being repeated in the generated text
|
||||||
|
clean_up_tokenization_spaces: (`optional`) bool whether to include extra spaces in the output
|
||||||
|
**generate_kwargs: extra kwargs passed to `self.model.generate`_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of dicts with 'summary_text' and/or 'summary_token_ids' for each document_to_summarize
|
||||||
|
|
||||||
|
.. _`self.model.generate`:
|
||||||
|
https://huggingface.co/transformers/model_doc/bart.html#transformers.BartForConditionalGeneration.generate
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True"
|
||||||
|
if self.framework == "tf":
|
||||||
|
raise NotImplementedError("Tensorflow not supported")
|
||||||
|
with self.device_placement():
|
||||||
|
inputs = self._parse_and_tokenize(*documents)
|
||||||
|
inputs = self.ensure_tensor_on_device(**inputs)
|
||||||
|
summaries = self.model.generate(
|
||||||
|
inputs["input_ids"],
|
||||||
|
attention_mask=inputs["attention_mask"],
|
||||||
|
max_length=max_length,
|
||||||
|
min_length=min_length,
|
||||||
|
do_sample=False,
|
||||||
|
**generate_kwargs,
|
||||||
|
)
|
||||||
|
results = []
|
||||||
|
for summary in summaries:
|
||||||
|
record = {}
|
||||||
|
if return_tensors:
|
||||||
|
record["summary_token_ids"] = summary
|
||||||
|
if return_text:
|
||||||
|
record["summary_text"] = self.tokenizer.decode(
|
||||||
|
summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces
|
||||||
|
)
|
||||||
|
results.append(record)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
# Register all the supported task here
|
# Register all the supported task here
|
||||||
SUPPORTED_TASKS = {
|
SUPPORTED_TASKS = {
|
||||||
"feature-extraction": {
|
"feature-extraction": {
|
||||||
@@ -1162,6 +1264,16 @@ SUPPORTED_TASKS = {
|
|||||||
"tokenizer": ("distilroberta-base", {"use_fast": False}),
|
"tokenizer": ("distilroberta-base", {"use_fast": False}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"summarization": {
|
||||||
|
"impl": SummarizationPipeline,
|
||||||
|
"pt": BartForConditionalGeneration if is_torch_available() else None,
|
||||||
|
"tf": None,
|
||||||
|
"default": {
|
||||||
|
"model": {"pt": "bart-large-cnn", "tf": None},
|
||||||
|
"config": None,
|
||||||
|
"tokenizer": ("bart-large-cnn", {"use_fast": False}),
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -1253,7 +1365,7 @@ def pipeline(
|
|||||||
|
|
||||||
# Use default model/config/tokenizer for the task if no model is provided
|
# Use default model/config/tokenizer for the task if no model is provided
|
||||||
if model is None:
|
if model is None:
|
||||||
models, config, tokenizer = tuple(targeted_task["default"].values())
|
models, config, tokenizer = [targeted_task["default"][k] for k in ["model", "config", "tokenizer"]]
|
||||||
model = models[framework]
|
model = models[framework]
|
||||||
|
|
||||||
# Try to infer tokenizer from model or config name (if provided as str)
|
# Try to infer tokenizer from model or config name (if provided as str)
|
||||||
|
|||||||
@@ -247,6 +247,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
expected_check_keys=["sequence"],
|
expected_check_keys=["sequence"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_summarization(self):
|
||||||
|
valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
|
||||||
|
invalid_inputs = [4, "<mask>"]
|
||||||
|
mandatory_keys = ["summary_text"]
|
||||||
|
nlp = pipeline(task="summarization")
|
||||||
|
self._test_mono_column_pipeline(
|
||||||
|
nlp, valid_inputs, invalid_inputs, mandatory_keys,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MultiColumnInputTestCase(unittest.TestCase):
|
class MultiColumnInputTestCase(unittest.TestCase):
|
||||||
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
|
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
|
||||||
|
|||||||
Reference in New Issue
Block a user