From 972535ea74c7b30987bc31c6621a2bbb58f82ca6 Mon Sep 17 00:00:00 2001 From: Joe Davison Date: Tue, 4 Aug 2020 16:37:49 -0400 Subject: [PATCH] fix zero shot pipeline docs (#6245) --- docs/source/main_classes/pipelines.rst | 8 ++++++++ src/transformers/__init__.py | 1 + src/transformers/pipelines.py | 11 ++++++----- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/docs/source/main_classes/pipelines.rst b/docs/source/main_classes/pipelines.rst index 067b7eca93..6bcbd399e1 100644 --- a/docs/source/main_classes/pipelines.rst +++ b/docs/source/main_classes/pipelines.rst @@ -20,6 +20,7 @@ There are two categories of pipeline abstractions to be aware about: - :class:`~transformers.TextGenerationPipeline` - :class:`~transformers.TokenClassificationPipeline` - :class:`~transformers.TranslationPipeline` + - :class:`~transformers.ZeroShotClassificationPipeline` The pipeline abstraction ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -97,6 +98,13 @@ TokenClassificationPipeline :special-members: __call__ :members: +ZeroShotClassificationPipeline +========================================== + +.. autoclass:: transformers.ZeroShotClassificationPipeline + :special-members: __call__ + :members: + Parent class: :obj:`Pipeline` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 18f6d72cef..f14f032d19 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -120,6 +120,7 @@ from .pipelines import ( TextGenerationPipeline, TokenClassificationPipeline, TranslationPipeline, + ZeroShotClassificationPipeline, pipeline, ) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 3cd252fd8f..8538233b39 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -1033,30 +1033,31 @@ class ZeroShotClassificationPipeline(Pipeline): Classify the sequence(s) given as inputs. 
Args: - sequences (:obj:`str` or obj:`List[str]`): + sequences (:obj:`str` or :obj:`List[str]`): The sequence(s) to classify, will be truncated if the model input is too large. - candidate_labels (:obj:`str` or obj:`List[str]`): + candidate_labels (:obj:`str` or :obj:`List[str]`): The set of possible class labels to classify each sequence into. Can be a single label, a string of comma-separated labels, or a list of labels. - hypothesis_template (obj:`str`, `optional`, defaults to :obj:`"This example is {}."`): + hypothesis_template (:obj:`str`, `optional`, defaults to :obj:`"This example is {}."`): The template used to turn each label into an NLI-style hypothesis. This template must include a {} or similar syntax for the candidate label to be inserted into the template. For example, the default template is :obj:`"This example is {}."` With the candidate label :obj:`"sports"`, this would be fed into the model like :obj:`"<cls> sequence to classify <sep> This example is sports . <sep>"`. The default template works well in many cases, but it may be worthwhile to experiment with different templates depending on the task setting. - multi_class (obj:`bool`, `optional`, defaults to :obj:`False`): + multi_class (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not multiple candidate labels can be true. If :obj:`False`, the scores are normalized such that the sum of the label likelihoods for each sequence is 1. If :obj:`True`, the labels are considered independent and probabilities are normalized for each candidate by doing a softmax of the entailment score vs. the contradiction score. + Return: A :obj:`dict` or a list of :obj:`dict`: Each result comes as a dictionary with the following keys: - **sequence** (:obj:`str`) -- The sequence for which this is the output. - **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood. - **scores** (:obj:` List[float]`) -- The probabilities for each of the labels. 
+ - **scores** (:obj:`List[float]`) -- The probabilities for each of the labels. """ outputs = super().__call__(sequences, candidate_labels, hypothesis_template) num_sequences = 1 if isinstance(sequences, str) else len(sequences)