Fix the behaviour of DefaultArgumentHandler (removing it). (#8180)
* Some work to fix the behaviour of DefaultArgumentHandler by removing it.
* Fixing argument checking for specific pipelines.
This commit is contained in:
@@ -23,9 +23,8 @@ import uuid
|
|||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from itertools import chain
|
|
||||||
from os.path import abspath, exists
|
from os.path import abspath, exists
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -185,57 +184,6 @@ class ArgumentHandler(ABC):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class DefaultArgumentHandler(ArgumentHandler):
|
|
||||||
"""
|
|
||||||
Default argument parser handling parameters for each :class:`~transformers.pipelines.Pipeline`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def handle_kwargs(kwargs: Dict) -> List:
|
|
||||||
if len(kwargs) == 1:
|
|
||||||
output = list(kwargs.values())
|
|
||||||
else:
|
|
||||||
output = list(chain(kwargs.values()))
|
|
||||||
|
|
||||||
return DefaultArgumentHandler.handle_args(output)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def handle_args(args: Sequence[Any]) -> List[str]:
|
|
||||||
|
|
||||||
# Only one argument, let's do case by case
|
|
||||||
if len(args) == 1:
|
|
||||||
if isinstance(args[0], str):
|
|
||||||
return [args[0]]
|
|
||||||
elif not isinstance(args[0], list):
|
|
||||||
return list(args)
|
|
||||||
else:
|
|
||||||
return args[0]
|
|
||||||
|
|
||||||
# Multiple arguments (x1, x2, ...)
|
|
||||||
elif len(args) > 1:
|
|
||||||
if all([isinstance(arg, str) for arg in args]):
|
|
||||||
return list(args)
|
|
||||||
|
|
||||||
# If not an instance of list, then it should be an instance of Iterable
|
|
||||||
elif isinstance(args, Iterable):
|
|
||||||
return list(chain.from_iterable(chain(args)))
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid input type {}. Pipeline supports Union[str, Iterable[str]]".format(type(args))
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
if len(kwargs) > 0 and len(args) > 0:
|
|
||||||
raise ValueError("Pipeline cannot handle mixed args and kwargs")
|
|
||||||
|
|
||||||
if len(kwargs) > 0:
|
|
||||||
return DefaultArgumentHandler.handle_kwargs(kwargs)
|
|
||||||
else:
|
|
||||||
return DefaultArgumentHandler.handle_args(args)
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineDataFormat:
|
class PipelineDataFormat:
|
||||||
"""
|
"""
|
||||||
Base class for all the pipeline supported data format both for reading and writing. Supported data formats
|
Base class for all the pipeline supported data format both for reading and writing. Supported data formats
|
||||||
@@ -574,7 +522,6 @@ class Pipeline(_ScikitCompat):
|
|||||||
self.framework = framework
|
self.framework = framework
|
||||||
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
|
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else "cuda:{}".format(device))
|
||||||
self.binary_output = binary_output
|
self.binary_output = binary_output
|
||||||
self._args_parser = args_parser or DefaultArgumentHandler()
|
|
||||||
|
|
||||||
# Special handling
|
# Special handling
|
||||||
if self.framework == "pt" and self.device.type == "cuda":
|
if self.framework == "pt" and self.device.type == "cuda":
|
||||||
@@ -669,12 +616,11 @@ class Pipeline(_ScikitCompat):
|
|||||||
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
|
f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
|
def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Parse arguments and tokenize
|
Parse arguments and tokenize
|
||||||
"""
|
"""
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
inputs = self._args_parser(*args, **kwargs)
|
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
inputs,
|
inputs,
|
||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
@@ -836,7 +782,7 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
|
|
||||||
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
|
# overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
|
||||||
|
|
||||||
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
|
def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
|
||||||
"""
|
"""
|
||||||
Parse arguments and tokenize
|
Parse arguments and tokenize
|
||||||
"""
|
"""
|
||||||
@@ -845,7 +791,6 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
|
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
|
||||||
else:
|
else:
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
inputs = self._args_parser(*args, **kwargs)
|
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
inputs,
|
inputs,
|
||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
@@ -858,7 +803,7 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
*args,
|
text_inputs,
|
||||||
return_tensors=False,
|
return_tensors=False,
|
||||||
return_text=True,
|
return_text=True,
|
||||||
clean_up_tokenization_spaces=False,
|
clean_up_tokenization_spaces=False,
|
||||||
@@ -890,7 +835,6 @@ class TextGenerationPipeline(Pipeline):
|
|||||||
- **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
|
- **generated_token_ids** (:obj:`torch.Tensor` or :obj:`tf.Tensor`, present when ``return_tensors=True``)
|
||||||
-- The token ids of the generated text.
|
-- The token ids of the generated text.
|
||||||
"""
|
"""
|
||||||
text_inputs = self._args_parser(*args)
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for prompt_text in text_inputs:
|
for prompt_text in text_inputs:
|
||||||
@@ -1094,7 +1038,8 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
|
def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
|
||||||
super().__init__(*args, args_parser=args_parser, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self._args_parser = args_parser
|
||||||
if self.entailment_id == -1:
|
if self.entailment_id == -1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
|
"Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
|
||||||
@@ -1108,13 +1053,15 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
return ind
|
return ind
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
|
def _parse_and_tokenize(
|
||||||
|
self, sequences, candidal_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
|
Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
|
||||||
"""
|
"""
|
||||||
inputs = self._args_parser(*args, **kwargs)
|
sequence_pairs = self._args_parser(sequences, candidal_labels, hypothesis_template)
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
inputs,
|
sequence_pairs,
|
||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
return_tensors=self.framework,
|
return_tensors=self.framework,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
@@ -1123,7 +1070,13 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
|
def __call__(
|
||||||
|
self,
|
||||||
|
sequences: Union[str, List[str]],
|
||||||
|
candidate_labels,
|
||||||
|
hypothesis_template="This example is {}.",
|
||||||
|
multi_class=False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline`
|
Classify the sequence(s) given as inputs. See the :obj:`~transformers.ZeroShotClassificationPipeline`
|
||||||
documentation for more information.
|
documentation for more information.
|
||||||
@@ -1154,8 +1107,11 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
- **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
|
- **labels** (:obj:`List[str]`) -- The labels sorted by order of likelihood.
|
||||||
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
|
- **scores** (:obj:`List[float]`) -- The probabilities for each of the labels.
|
||||||
"""
|
"""
|
||||||
|
if sequences and isinstance(sequences, str):
|
||||||
|
sequences = [sequences]
|
||||||
|
|
||||||
outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
|
outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
|
||||||
num_sequences = 1 if isinstance(sequences, str) else len(sequences)
|
num_sequences = len(sequences)
|
||||||
candidate_labels = self._args_parser._parse_labels(candidate_labels)
|
candidate_labels = self._args_parser._parse_labels(candidate_labels)
|
||||||
reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))
|
reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))
|
||||||
|
|
||||||
@@ -1425,12 +1381,12 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
self.ignore_labels = ignore_labels
|
self.ignore_labels = ignore_labels
|
||||||
self.grouped_entities = grouped_entities
|
self.grouped_entities = grouped_entities
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, inputs: Union[str, List[str]], **kwargs):
|
||||||
"""
|
"""
|
||||||
Classify each token of the text(s) given as inputs.
|
Classify each token of the text(s) given as inputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
args (:obj:`str` or :obj:`List[str]`):
|
inputs (:obj:`str` or :obj:`List[str]`):
|
||||||
One or several texts (or one list of texts) for token classification.
|
One or several texts (or one list of texts) for token classification.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
@@ -1444,7 +1400,8 @@ class TokenClassificationPipeline(Pipeline):
|
|||||||
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
|
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
|
||||||
corresponding token in the sentence.
|
corresponding token in the sentence.
|
||||||
"""
|
"""
|
||||||
inputs = self._args_parser(*args, **kwargs)
|
if isinstance(inputs, str):
|
||||||
|
inputs = [inputs]
|
||||||
answers = []
|
answers = []
|
||||||
for sentence in inputs:
|
for sentence in inputs:
|
||||||
|
|
||||||
@@ -1659,12 +1616,12 @@ class QuestionAnsweringPipeline(Pipeline):
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
modelcard=modelcard,
|
modelcard=modelcard,
|
||||||
framework=framework,
|
framework=framework,
|
||||||
args_parser=QuestionAnsweringArgumentHandler(),
|
|
||||||
device=device,
|
device=device,
|
||||||
task=task,
|
task=task,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._args_parser = QuestionAnsweringArgumentHandler()
|
||||||
self.check_model_type(
|
self.check_model_type(
|
||||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||||
)
|
)
|
||||||
@@ -2489,12 +2446,11 @@ class ConversationalPipeline(Pipeline):
|
|||||||
else:
|
else:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _parse_and_tokenize(self, *args, **kwargs):
|
def _parse_and_tokenize(self, inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Parse arguments and tokenize, adding an EOS token at the end of the user input
|
Parse arguments and tokenize, adding an EOS token at the end of the user input
|
||||||
"""
|
"""
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
inputs = self._args_parser(*args, **kwargs)
|
|
||||||
inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
|
inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
|
||||||
for input in inputs:
|
for input in inputs:
|
||||||
input.append(self.tokenizer.eos_token_id)
|
input.append(self.tokenizer.eos_token_id)
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import unittest
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from transformers import is_tf_available, is_torch_available, pipeline
|
from transformers import is_tf_available, is_torch_available, pipeline
|
||||||
from transformers.pipelines import DefaultArgumentHandler, Pipeline
|
|
||||||
|
# from transformers.pipelines import DefaultArgumentHandler, Pipeline
|
||||||
|
from transformers.pipelines import Pipeline
|
||||||
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
|
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -200,74 +201,74 @@ class MonoInputPipelineCommonMixin:
|
|||||||
self.assertRaises(Exception, nlp, self.invalid_inputs)
|
self.assertRaises(Exception, nlp, self.invalid_inputs)
|
||||||
|
|
||||||
|
|
||||||
@is_pipeline_test
|
# @is_pipeline_test
|
||||||
class DefaultArgumentHandlerTestCase(unittest.TestCase):
|
# class DefaultArgumentHandlerTestCase(unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
# def setUp(self) -> None:
|
||||||
self.handler = DefaultArgumentHandler()
|
# self.handler = DefaultArgumentHandler()
|
||||||
|
#
|
||||||
def test_kwargs_x(self):
|
# def test_kwargs_x(self):
|
||||||
mono_data = {"X": "This is a sample input"}
|
# mono_data = {"X": "This is a sample input"}
|
||||||
mono_args = self.handler(**mono_data)
|
# mono_args = self.handler(**mono_data)
|
||||||
|
#
|
||||||
self.assertTrue(isinstance(mono_args, list))
|
# self.assertTrue(isinstance(mono_args, list))
|
||||||
self.assertEqual(len(mono_args), 1)
|
# self.assertEqual(len(mono_args), 1)
|
||||||
|
#
|
||||||
multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
|
# multi_data = {"x": ["This is a sample input", "This is a second sample input"]}
|
||||||
multi_args = self.handler(**multi_data)
|
# multi_args = self.handler(**multi_data)
|
||||||
|
#
|
||||||
self.assertTrue(isinstance(multi_args, list))
|
# self.assertTrue(isinstance(multi_args, list))
|
||||||
self.assertEqual(len(multi_args), 2)
|
# self.assertEqual(len(multi_args), 2)
|
||||||
|
#
|
||||||
def test_kwargs_data(self):
|
# def test_kwargs_data(self):
|
||||||
mono_data = {"data": "This is a sample input"}
|
# mono_data = {"data": "This is a sample input"}
|
||||||
mono_args = self.handler(**mono_data)
|
# mono_args = self.handler(**mono_data)
|
||||||
|
#
|
||||||
self.assertTrue(isinstance(mono_args, list))
|
# self.assertTrue(isinstance(mono_args, list))
|
||||||
self.assertEqual(len(mono_args), 1)
|
# self.assertEqual(len(mono_args), 1)
|
||||||
|
#
|
||||||
multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
|
# multi_data = {"data": ["This is a sample input", "This is a second sample input"]}
|
||||||
multi_args = self.handler(**multi_data)
|
# multi_args = self.handler(**multi_data)
|
||||||
|
#
|
||||||
self.assertTrue(isinstance(multi_args, list))
|
# self.assertTrue(isinstance(multi_args, list))
|
||||||
self.assertEqual(len(multi_args), 2)
|
# self.assertEqual(len(multi_args), 2)
|
||||||
|
#
|
||||||
def test_multi_kwargs(self):
|
# def test_multi_kwargs(self):
|
||||||
mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
|
# mono_data = {"data": "This is a sample input", "X": "This is a sample input 2"}
|
||||||
mono_args = self.handler(**mono_data)
|
# mono_args = self.handler(**mono_data)
|
||||||
|
#
|
||||||
self.assertTrue(isinstance(mono_args, list))
|
# self.assertTrue(isinstance(mono_args, list))
|
||||||
self.assertEqual(len(mono_args), 2)
|
# self.assertEqual(len(mono_args), 2)
|
||||||
|
#
|
||||||
multi_data = {
|
# multi_data = {
|
||||||
"data": ["This is a sample input", "This is a second sample input"],
|
# "data": ["This is a sample input", "This is a second sample input"],
|
||||||
"test": ["This is a sample input 2", "This is a second sample input 2"],
|
# "test": ["This is a sample input 2", "This is a second sample input 2"],
|
||||||
}
|
# }
|
||||||
multi_args = self.handler(**multi_data)
|
# multi_args = self.handler(**multi_data)
|
||||||
|
#
|
||||||
self.assertTrue(isinstance(multi_args, list))
|
# self.assertTrue(isinstance(multi_args, list))
|
||||||
self.assertEqual(len(multi_args), 4)
|
# self.assertEqual(len(multi_args), 4)
|
||||||
|
#
|
||||||
def test_args(self):
|
# def test_args(self):
|
||||||
mono_data = "This is a sample input"
|
# mono_data = "This is a sample input"
|
||||||
mono_args = self.handler(mono_data)
|
# mono_args = self.handler(mono_data)
|
||||||
|
#
|
||||||
self.assertTrue(isinstance(mono_args, list))
|
# self.assertTrue(isinstance(mono_args, list))
|
||||||
self.assertEqual(len(mono_args), 1)
|
# self.assertEqual(len(mono_args), 1)
|
||||||
|
#
|
||||||
mono_data = ["This is a sample input"]
|
# mono_data = ["This is a sample input"]
|
||||||
mono_args = self.handler(mono_data)
|
# mono_args = self.handler(mono_data)
|
||||||
|
#
|
||||||
self.assertTrue(isinstance(mono_args, list))
|
# self.assertTrue(isinstance(mono_args, list))
|
||||||
self.assertEqual(len(mono_args), 1)
|
# self.assertEqual(len(mono_args), 1)
|
||||||
|
#
|
||||||
multi_data = ["This is a sample input", "This is a second sample input"]
|
# multi_data = ["This is a sample input", "This is a second sample input"]
|
||||||
multi_args = self.handler(multi_data)
|
# multi_args = self.handler(multi_data)
|
||||||
|
#
|
||||||
self.assertTrue(isinstance(multi_args, list))
|
# self.assertTrue(isinstance(multi_args, list))
|
||||||
self.assertEqual(len(multi_args), 2)
|
# self.assertEqual(len(multi_args), 2)
|
||||||
|
#
|
||||||
multi_data = ["This is a sample input", "This is a second sample input"]
|
# multi_data = ["This is a sample input", "This is a second sample input"]
|
||||||
multi_args = self.handler(*multi_data)
|
# multi_args = self.handler(*multi_data)
|
||||||
|
#
|
||||||
self.assertTrue(isinstance(multi_args, list))
|
# self.assertTrue(isinstance(multi_args, list))
|
||||||
self.assertEqual(len(multi_args), 2)
|
# self.assertEqual(len(multi_args), 2)
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
from transformers.testing_utils import require_tf, require_torch, slow
|
from transformers.testing_utils import require_tf, require_torch, slow
|
||||||
|
|
||||||
@@ -37,7 +39,7 @@ EXPECTED_FILL_MASK_TARGET_RESULT = [
|
|||||||
|
|
||||||
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||||
pipeline_task = "fill-mask"
|
pipeline_task = "fill-mask"
|
||||||
pipeline_loading_kwargs = {"topk": 2}
|
pipeline_loading_kwargs = {"top_k": 2}
|
||||||
small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator
|
small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator
|
||||||
large_models = ["distilroberta-base"] # Models tested with the @slow decorator
|
large_models = ["distilroberta-base"] # Models tested with the @slow decorator
|
||||||
mandatory_keys = {"sequence", "score", "token"}
|
mandatory_keys = {"sequence", "score", "token"}
|
||||||
@@ -51,6 +53,28 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
]
|
]
|
||||||
expected_check_keys = ["sequence"]
|
expected_check_keys = ["sequence"]
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_torch_topk_deprecation(self):
|
||||||
|
# `topk` is only checked at pipeline initialization: it was not enabled at
|
||||||
|
# the pipeline call site before
|
||||||
|
with pytest.warns(FutureWarning, match=r".*use `top_k`.*"):
|
||||||
|
pipeline(task="fill-mask", model=self.small_models[0], topk=1)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_torch_fill_mask(self):
|
||||||
|
valid_inputs = "My name is <mask>"
|
||||||
|
nlp = pipeline(task="fill-mask", model=self.small_models[0])
|
||||||
|
outputs = nlp(valid_inputs)
|
||||||
|
self.assertIsInstance(outputs, list)
|
||||||
|
|
||||||
|
# This passes
|
||||||
|
outputs = nlp(valid_inputs, targets=[" Patrick", " Clara"])
|
||||||
|
self.assertIsInstance(outputs, list)
|
||||||
|
|
||||||
|
# This used to fail with `cannot mix args and kwargs`
|
||||||
|
outputs = nlp(valid_inputs, something=False)
|
||||||
|
self.assertIsInstance(outputs, list)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_torch_fill_mask_with_targets(self):
|
def test_torch_fill_mask_with_targets(self):
|
||||||
valid_inputs = ["My name is <mask>"]
|
valid_inputs = ["My name is <mask>"]
|
||||||
@@ -94,7 +118,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
tokenizer=model_name,
|
tokenizer=model_name,
|
||||||
framework="pt",
|
framework="pt",
|
||||||
topk=2,
|
top_k=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
mono_result = nlp(valid_inputs[0], targets=valid_targets)
|
mono_result = nlp(valid_inputs[0], targets=valid_targets)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class ZeroShotClassificationPipelineTests(CustomInputPipelineCommonMixin, unitte
|
|||||||
sum = 0.0
|
sum = 0.0
|
||||||
for score in result["scores"]:
|
for score in result["scores"]:
|
||||||
sum += score
|
sum += score
|
||||||
self.assertAlmostEqual(sum, 1.0)
|
self.assertAlmostEqual(sum, 1.0, places=5)
|
||||||
|
|
||||||
def _test_entailment_id(self, nlp: Pipeline):
|
def _test_entailment_id(self, nlp: Pipeline):
|
||||||
config = nlp.model.config
|
config = nlp.model.config
|
||||||
|
|||||||
Reference in New Issue
Block a user