From 38a555a83c8aceae77895d325174af5bd576cec7 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 17 Mar 2020 18:04:21 -0400 Subject: [PATCH] Add Summarization to Pipelines (#3128) * passing * Undo stupid chg * docs * undo rename * delete-cruft * only import if you have torch * Dont rely on dict ordering * Fix dict ordering upstream * docstring link * docstring link * remove trailing comma for 3.5 compat * new name * delegate kwarging * Update kwargs --- docs/source/main_classes/pipelines.rst | 5 ++ src/transformers/__init__.py | 1 + src/transformers/pipelines.py | 114 ++++++++++++++++++++++++- tests/test_pipelines.py | 10 +++ 4 files changed, 129 insertions(+), 1 deletion(-) diff --git a/docs/source/main_classes/pipelines.rst b/docs/source/main_classes/pipelines.rst index 9e8bfa8af8..85bf673d85 100644 --- a/docs/source/main_classes/pipelines.rst +++ b/docs/source/main_classes/pipelines.rst @@ -61,3 +61,8 @@ QuestionAnsweringPipeline .. autoclass:: transformers.QuestionAnsweringPipeline + +SummarizationPipeline +========================================== + +.. autoclass:: transformers.SummarizationPipeline diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1c9e1ac4c7..a09096aa69 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -113,6 +113,7 @@ from .pipelines import ( Pipeline, PipelineDataFormat, QuestionAnsweringPipeline, + SummarizationPipeline, TextClassificationPipeline, TokenClassificationPipeline, pipeline, diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 54745b05a3..446cd2d57d 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -60,6 +60,7 @@ if is_torch_available(): AutoModelForTokenClassification, AutoModelWithLMHead, ) + from .modeling_bart import BartForConditionalGeneration logger = logging.getLogger(__name__) @@ -1104,6 +1105,107 @@ class QuestionAnsweringPipeline(Pipeline): return {"answer": " ".join(words), "start": max(0, char_start_idx), "end": min(len(text), char_end_idx)} +class SummarizationPipeline(Pipeline): + """ + Summarize news articles and other documents + + Usage:: + + summarizer = pipeline("summarization") + summarizer("Sam Shleifer writes the best docstring examples in the whole world.") + + Supported Models: + The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is + currently only ``BartForConditionalGeneration.from_pretrained('bart-large-cnn')`` + + Arguments: + model (:obj:`str` or :obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`, `optional`, defaults to :obj:`None`): + The model that will be used by the pipeline to make predictions. This can be :obj:`None`, a string + checkpoint identifier or an actual pre-trained model inheriting from + :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for + TensorFlow. + + If :obj:`None`, the default of the pipeline will be loaded. + tokenizer (:obj:`str` or :obj:`~transformers.PreTrainedTokenizer`, `optional`, defaults to :obj:`None`): + The tokenizer that will be used by the pipeline to encode data for the model. This can be :obj:`None`, + a string checkpoint identifier or an actual pre-trained tokenizer inheriting from + :class:`~transformers.PreTrainedTokenizer`. + + If :obj:`None`, the default of the pipeline will be loaded. + modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`): + Model card attributed to the model for this pipeline. + framework (:obj:`str`, `optional`, defaults to :obj:`None`): + The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be + installed. + + If no framework is specified, will default to the one currently installed. If no framework is specified + and both frameworks are installed, will default to PyTorch. + args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`): + Reference to the object in charge of parsing supplied pipeline parameters. + device (:obj:`int`, `optional`, defaults to :obj:`-1`): + Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model + on the associated CUDA device id. + """ + + task = "summarization" + + def __call__( + self, + *documents, + return_tensors=False, + return_text=True, + max_length=142, + min_length=21, + clean_up_tokenization_spaces=False, + **generate_kwargs + ): + r""" + Args: + *documents: (list of strings) articles to be summarized + return_text: (bool, default=True) whether to add a decoded "summary_text" to each result + return_tensors: (bool, default=False) whether to return the raw "summary_token_ids" to each result + + max_length: (`optional`) int + The max length of the sequence to be generated. Does not include tokens in input_ids. + min_len: (`optional`) int + no_repeat_ngram_size: (`optional`) int. ban ngrams of this length from being repeated in the generated text + clean_up_tokenization_spaces: (`optional`) bool whether to include extra spaces in the output + **generate_kwargs: extra kwargs passed to `self.model.generate`_ + + Returns: + list of dicts with 'summary_text' and/or 'summary_token_ids' for each document_to_summarize + + .. _`self.model.generate`: + https://huggingface.co/transformers/model_doc/bart.html#transformers.BartForConditionalGeneration.generate + + """ + assert return_tensors or return_text, "You must specify return_tensors=True or return_text=True" + if self.framework == "tf": + raise NotImplementedError("Tensorflow not supported") + with self.device_placement(): + inputs = self._parse_and_tokenize(*documents) + inputs = self.ensure_tensor_on_device(**inputs) + summaries = self.model.generate( + inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_length=max_length, + min_length=min_length, + do_sample=False, + **generate_kwargs, + ) + results = [] + for summary in summaries: + record = {} + if return_tensors: + record["summary_token_ids"] = summary + if return_text: + record["summary_text"] = self.tokenizer.decode( + summary, skip_special_tokens=True, clean_up_tokenization_spaces=clean_up_tokenization_spaces + ) + results.append(record) + return results + + # Register all the supported task here SUPPORTED_TASKS = { "feature-extraction": { @@ -1162,6 +1264,16 @@ SUPPORTED_TASKS = { "tokenizer": ("distilroberta-base", {"use_fast": False}), }, }, + "summarization": { + "impl": SummarizationPipeline, + "pt": BartForConditionalGeneration if is_torch_available() else None, + "tf": None, + "default": { + "model": {"pt": "bart-large-cnn", "tf": None}, + "config": None, + "tokenizer": ("bart-large-cnn", {"use_fast": False}), + }, + }, } @@ -1253,7 +1365,7 @@ def pipeline( # Use default model/config/tokenizer for the task if no model is provided if model is None: - models, config, tokenizer = tuple(targeted_task["default"].values()) + models, config, tokenizer = [targeted_task["default"][k] for k in ["model", "config", "tokenizer"]] model = models[framework] # Try to infer tokenizer from model or config name (if provided as str) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 09776d63cf..fb5e78ae9a 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -247,6 +247,16 @@ class MonoColumnInputTestCase(unittest.TestCase): expected_check_keys=["sequence"], ) + @require_torch + def test_summarization(self): + valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]] + invalid_inputs = [4, ""] + mandatory_keys = ["summary_text"] + nlp = pipeline(task="summarization") + self._test_mono_column_pipeline( + nlp, valid_inputs, invalid_inputs, mandatory_keys, + ) + class MultiColumnInputTestCase(unittest.TestCase): def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):