Moving fill-mask pipeline to new testing scheme (#12943)
* Fill mask pipelines test updates. * Model eval !! * Adding slow test with actual values. * Making all tests pass (skipping quite a bit.) * Doc styling. * Better doc cleanup. * Making an explicit test with no pad token tokenizer. * Typo.
This commit is contained in:
@@ -748,6 +748,8 @@ class Pipeline(_ScikitCompat):
|
|||||||
Parse arguments and tokenize
|
Parse arguments and tokenize
|
||||||
"""
|
"""
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
padding = False
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
inputs,
|
inputs,
|
||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -9,6 +9,8 @@ from ..utils import logging
|
|||||||
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline, PipelineException
|
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline, PipelineException
|
||||||
|
|
||||||
|
|
||||||
|
GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..modeling_tf_utils import TFPreTrainedModel
|
from ..modeling_tf_utils import TFPreTrainedModel
|
||||||
from ..modeling_utils import PreTrainedModel
|
from ..modeling_utils import PreTrainedModel
|
||||||
@@ -30,7 +32,13 @@ logger = logging.get_logger(__name__)
|
|||||||
@add_end_docstrings(
|
@add_end_docstrings(
|
||||||
PIPELINE_INIT_ARGS,
|
PIPELINE_INIT_ARGS,
|
||||||
r"""
|
r"""
|
||||||
top_k (:obj:`int`, defaults to 5): The number of predictions to return.
|
top_k (:obj:`int`, defaults to 5):
|
||||||
|
The number of predictions to return.
|
||||||
|
targets (:obj:`str` or :obj:`List[str]`, `optional`):
|
||||||
|
When passed, the model will limit the scores to the passed targets instead of looking up in the whole
|
||||||
|
vocab. If the provided targets are not in the model vocab, they will be tokenized and the first resulting
|
||||||
|
token will be used (with a warning, and that might be slower).
|
||||||
|
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
class FillMaskPipeline(Pipeline):
|
class FillMaskPipeline(Pipeline):
|
||||||
@@ -59,6 +67,7 @@ class FillMaskPipeline(Pipeline):
|
|||||||
args_parser: ArgumentHandler = None,
|
args_parser: ArgumentHandler = None,
|
||||||
device: int = -1,
|
device: int = -1,
|
||||||
top_k=5,
|
top_k=5,
|
||||||
|
targets=None,
|
||||||
task: str = "",
|
task: str = "",
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -74,8 +83,23 @@ class FillMaskPipeline(Pipeline):
|
|||||||
|
|
||||||
self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
|
self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
self.targets = targets
|
||||||
|
if self.tokenizer.mask_token_id is None:
|
||||||
|
raise PipelineException(
|
||||||
|
"fill-mask", self.model.base_model_prefix, "The tokenizer does not define a `mask_token`."
|
||||||
|
)
|
||||||
|
|
||||||
def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
|
def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
|
||||||
|
if self.framework == "tf":
|
||||||
|
masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
|
||||||
|
elif self.framework == "pt":
|
||||||
|
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported framework")
|
||||||
|
return masked_index
|
||||||
|
|
||||||
|
def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
|
||||||
|
masked_index = self.get_masked_index(input_ids)
|
||||||
numel = np.prod(masked_index.shape)
|
numel = np.prod(masked_index.shape)
|
||||||
if numel > 1:
|
if numel > 1:
|
||||||
raise PipelineException(
|
raise PipelineException(
|
||||||
@@ -90,7 +114,25 @@ class FillMaskPipeline(Pipeline):
|
|||||||
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
|
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
|
def ensure_exactly_one_mask_token(self, model_inputs: GenericTensor):
|
||||||
|
if isinstance(model_inputs, list):
|
||||||
|
for model_input in model_inputs:
|
||||||
|
self._ensure_exactly_one_mask_token(model_input["input_ids"][0])
|
||||||
|
else:
|
||||||
|
for input_ids in model_inputs["input_ids"]:
|
||||||
|
self._ensure_exactly_one_mask_token(input_ids)
|
||||||
|
|
||||||
|
def get_model_inputs(self, inputs, *args, **kwargs) -> Dict:
|
||||||
|
if isinstance(inputs, list) and self.tokenizer.pad_token is None:
|
||||||
|
model_inputs = []
|
||||||
|
for input_ in inputs:
|
||||||
|
model_input = self._parse_and_tokenize(input_, padding=False, *args, **kwargs)
|
||||||
|
model_inputs.append(model_input)
|
||||||
|
else:
|
||||||
|
model_inputs = self._parse_and_tokenize(inputs, *args, **kwargs)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
def __call__(self, inputs, *args, targets=None, top_k: Optional[int] = None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Fill the masked token in the text(s) given as inputs.
|
Fill the masked token in the text(s) given as inputs.
|
||||||
|
|
||||||
@@ -112,16 +154,27 @@ class FillMaskPipeline(Pipeline):
|
|||||||
- **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
|
- **token** (:obj:`int`) -- The predicted token id (to replace the masked one).
|
||||||
- **token** (:obj:`str`) -- The predicted token (to replace the masked one).
|
- **token** (:obj:`str`) -- The predicted token (to replace the masked one).
|
||||||
"""
|
"""
|
||||||
inputs = self._parse_and_tokenize(*args, **kwargs)
|
model_inputs = self.get_model_inputs(inputs, *args, **kwargs)
|
||||||
outputs = self._forward(inputs, return_tensors=True)
|
self.ensure_exactly_one_mask_token(model_inputs)
|
||||||
|
if isinstance(model_inputs, list):
|
||||||
|
outputs = []
|
||||||
|
for model_input in model_inputs:
|
||||||
|
output = self._forward(model_input, return_tensors=True)
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
|
batch_size = len(model_inputs)
|
||||||
|
else:
|
||||||
|
outputs = self._forward(model_inputs, return_tensors=True)
|
||||||
|
batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)
|
||||||
|
|
||||||
# top_k must be defined
|
# top_k must be defined
|
||||||
if top_k is None:
|
if top_k is None:
|
||||||
top_k = self.top_k
|
top_k = self.top_k
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)
|
|
||||||
|
|
||||||
|
if targets is None and self.targets is not None:
|
||||||
|
targets = self.targets
|
||||||
if targets is not None:
|
if targets is not None:
|
||||||
if isinstance(targets, str):
|
if isinstance(targets, str):
|
||||||
targets = [targets]
|
targets = [targets]
|
||||||
@@ -167,16 +220,21 @@ class FillMaskPipeline(Pipeline):
|
|||||||
top_k = target_ids.shape[0]
|
top_k = target_ids.shape[0]
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
input_ids = inputs["input_ids"][i]
|
if isinstance(model_inputs, list):
|
||||||
|
input_ids = model_inputs[i]["input_ids"][0]
|
||||||
|
else:
|
||||||
|
input_ids = model_inputs["input_ids"][i]
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
if self.framework == "tf":
|
if self.framework == "tf":
|
||||||
masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
|
masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
|
||||||
|
|
||||||
# Fill mask pipeline supports only one ${mask_token} per sample
|
# Fill mask pipeline supports only one ${mask_token} per sample
|
||||||
self.ensure_exactly_one_mask_token(masked_index)
|
|
||||||
|
|
||||||
logits = outputs[i, masked_index.item(), :]
|
if isinstance(outputs, list):
|
||||||
|
logits = outputs[i][0, masked_index.item(), :]
|
||||||
|
else:
|
||||||
|
logits = outputs[i, masked_index.item(), :]
|
||||||
probs = tf.nn.softmax(logits)
|
probs = tf.nn.softmax(logits)
|
||||||
if targets is not None:
|
if targets is not None:
|
||||||
probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
|
probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
|
||||||
@@ -185,11 +243,12 @@ class FillMaskPipeline(Pipeline):
|
|||||||
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
||||||
else:
|
else:
|
||||||
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
|
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
|
||||||
|
|
||||||
# Fill mask pipeline supports only one ${mask_token} per sample
|
# Fill mask pipeline supports only one ${mask_token} per sample
|
||||||
self.ensure_exactly_one_mask_token(masked_index.numpy())
|
|
||||||
|
|
||||||
logits = outputs[i, masked_index.item(), :]
|
if isinstance(outputs, list):
|
||||||
|
logits = outputs[i][0, masked_index.item(), :]
|
||||||
|
else:
|
||||||
|
logits = outputs[i, masked_index.item(), :]
|
||||||
probs = logits.softmax(dim=0)
|
probs = logits.softmax(dim=0)
|
||||||
if targets is not None:
|
if targets is not None:
|
||||||
probs = probs[..., target_ids]
|
probs = probs[..., target_ids]
|
||||||
|
|||||||
@@ -189,6 +189,7 @@ class ReformerModelTester:
|
|||||||
def get_pipeline_config(self):
|
def get_pipeline_config(self):
|
||||||
config = self.get_config()
|
config = self.get_config()
|
||||||
config.vocab_size = 100
|
config.vocab_size = 100
|
||||||
|
config.is_decoder = False
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
|
def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
|
||||||
|
|||||||
@@ -74,10 +74,10 @@ def get_tiny_config_from_class(configuration_class):
|
|||||||
@lru_cache(maxsize=100)
|
@lru_cache(maxsize=100)
|
||||||
def get_tiny_tokenizer_from_checkpoint(checkpoint):
|
def get_tiny_tokenizer_from_checkpoint(checkpoint):
|
||||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||||
logger.warning("Training new from iterator ...")
|
logger.info("Training new from iterator ...")
|
||||||
vocabulary = string.ascii_letters + string.digits + " "
|
vocabulary = string.ascii_letters + string.digits + " "
|
||||||
tokenizer = tokenizer.train_new_from_iterator(vocabulary, vocab_size=len(vocabulary), show_progress=False)
|
tokenizer = tokenizer.train_new_from_iterator(vocabulary, vocab_size=len(vocabulary), show_progress=False)
|
||||||
logger.warning("Trained.")
|
logger.info("Trained.")
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -109,9 +109,7 @@ class PipelineTestCaseMeta(type):
|
|||||||
# Some test tokenizer contain broken vocabs or custom PreTokenizer, so we
|
# Some test tokenizer contain broken vocabs or custom PreTokenizer, so we
|
||||||
# provide some default tokenizer and hope for the best.
|
# provide some default tokenizer and hope for the best.
|
||||||
except: # noqa: E722
|
except: # noqa: E722
|
||||||
logger.warning(f"Tokenizer cannot be created from checkpoint {checkpoint}")
|
self.skipTest(f"Ignoring {ModelClass}, cannot create a simple tokenizer")
|
||||||
tokenizer = get_tiny_tokenizer_from_checkpoint("gpt2")
|
|
||||||
tokenizer.model_max_length = model.config.max_position_embeddings
|
|
||||||
self.run_pipeline_test(model, tokenizer)
|
self.run_pipeline_test(model, tokenizer)
|
||||||
|
|
||||||
return test
|
return test
|
||||||
|
|||||||
@@ -14,304 +14,307 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import pipeline
|
from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
|
||||||
from transformers.testing_utils import nested_simplify, require_tf, require_torch, slow
|
from transformers.pipelines import PipelineException
|
||||||
|
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
|
||||||
|
|
||||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
|
|
||||||
|
|
||||||
EXPECTED_FILL_MASK_RESULT = [
|
@is_pipeline_test
|
||||||
[
|
class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||||
{"sequence": "My name is John", "score": 0.00782308354973793, "token": 610, "token_str": " John"},
|
model_mapping = MODEL_FOR_MASKED_LM_MAPPING
|
||||||
{"sequence": "My name is Chris", "score": 0.007475061342120171, "token": 1573, "token_str": " Chris"},
|
tf_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
|
||||||
],
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"sequence": "The largest city in France is Paris",
|
|
||||||
"score": 0.2510891854763031,
|
|
||||||
"token": 2201,
|
|
||||||
"token_str": " Paris",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"sequence": "The largest city in France is Lyon",
|
|
||||||
"score": 0.21418564021587372,
|
|
||||||
"token": 12790,
|
|
||||||
"token_str": " Lyon",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
]
|
|
||||||
|
|
||||||
EXPECTED_FILL_MASK_TARGET_RESULT = [EXPECTED_FILL_MASK_RESULT[0]]
|
@require_tf
|
||||||
|
def test_small_model_tf(self):
|
||||||
|
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", top_k=2, framework="tf")
|
||||||
class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
outputs = unmasker("My name is <mask>")
|
||||||
pipeline_task = "fill-mask"
|
|
||||||
pipeline_loading_kwargs = {"top_k": 2}
|
|
||||||
small_models = ["sshleifer/tiny-distilroberta-base"] # Models tested without the @slow decorator
|
|
||||||
large_models = ["distilroberta-base"] # Models tested with the @slow decorator
|
|
||||||
mandatory_keys = {"sequence", "score", "token"}
|
|
||||||
valid_inputs = [
|
|
||||||
"My name is <mask>",
|
|
||||||
"The largest city in France is <mask>",
|
|
||||||
]
|
|
||||||
invalid_inputs = [
|
|
||||||
"This is <mask> <mask>" # More than 1 mask_token in the input is not supported
|
|
||||||
"This is" # No mask_token is not supported
|
|
||||||
]
|
|
||||||
expected_check_keys = ["sequence"]
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_torch_fill_mask(self):
|
|
||||||
valid_inputs = "My name is <mask>"
|
|
||||||
unmasker = pipeline(task="fill-mask", model=self.small_models[0])
|
|
||||||
outputs = unmasker(valid_inputs)
|
|
||||||
self.assertIsInstance(outputs, list)
|
|
||||||
|
|
||||||
# This passes
|
|
||||||
outputs = unmasker(valid_inputs, targets=[" Patrick", " Clara"])
|
|
||||||
self.assertIsInstance(outputs, list)
|
|
||||||
|
|
||||||
# This used to fail with `cannot mix args and kwargs`
|
|
||||||
outputs = unmasker(valid_inputs, something=False)
|
|
||||||
self.assertIsInstance(outputs, list)
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_torch_fill_mask_with_targets(self):
|
|
||||||
valid_inputs = ["My name is <mask>"]
|
|
||||||
# ' Sam' will yield a warning but work
|
|
||||||
valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
|
|
||||||
invalid_targets = [[], [""], ""]
|
|
||||||
for model_name in self.small_models:
|
|
||||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
|
|
||||||
for targets in valid_targets:
|
|
||||||
outputs = unmasker(valid_inputs, targets=targets)
|
|
||||||
self.assertIsInstance(outputs, list)
|
|
||||||
self.assertEqual(len(outputs), len(targets))
|
|
||||||
for targets in invalid_targets:
|
|
||||||
self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
@slow
|
|
||||||
def test_torch_fill_mask_targets_equivalence(self):
|
|
||||||
model_name = self.large_models[0]
|
|
||||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
|
|
||||||
unmasked = unmasker(self.valid_inputs[0])
|
|
||||||
tokens = [top_mask["token_str"] for top_mask in unmasked]
|
|
||||||
scores = [top_mask["score"] for top_mask in unmasked]
|
|
||||||
|
|
||||||
unmasked_targets = unmasker(self.valid_inputs[0], targets=tokens)
|
|
||||||
target_scores = [top_mask["score"] for top_mask in unmasked_targets]
|
|
||||||
|
|
||||||
self.assertEqual(scores, target_scores)
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
def test_torch_fill_mask_with_targets_and_topk(self):
|
|
||||||
model_name = self.small_models[0]
|
|
||||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
|
|
||||||
targets = [" Teven", "ĠPatrick", "ĠClara"]
|
|
||||||
top_k = 2
|
|
||||||
outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(outputs),
|
nested_simplify(outputs, decimals=6),
|
||||||
[
|
[
|
||||||
{"sequence": "My name is Patrick", "score": 0.0, "token": 3499, "token_str": " Patrick"},
|
{"sequence": "My name is grouped", "score": 2.1e-05, "token": 38015, "token_str": " grouped"},
|
||||||
{"sequence": "My name is Te", "score": 0.0, "token": 2941, "token_str": " Te"},
|
{"sequence": "My name is accuser", "score": 2.1e-05, "token": 25506, "token_str": " accuser"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = unmasker("The largest city in France is <mask>")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=6),
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"sequence": "The largest city in France is grouped",
|
||||||
|
"score": 2.1e-05,
|
||||||
|
"token": 38015,
|
||||||
|
"token_str": " grouped",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"sequence": "The largest city in France is accuser",
|
||||||
|
"score": 2.1e-05,
|
||||||
|
"token": 25506,
|
||||||
|
"token_str": " accuser",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = unmasker("My name is <mask>", targets=[" Patrick", " Clara", " Teven"], top_k=3)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=6),
|
||||||
|
[
|
||||||
|
{"sequence": "My name is Clara", "score": 2e-05, "token": 13606, "token_str": " Clara"},
|
||||||
|
{"sequence": "My name is Patrick", "score": 2e-05, "token": 3499, "token_str": " Patrick"},
|
||||||
|
{"sequence": "My name is Te", "score": 1.9e-05, "token": 2941, "token_str": " Te"},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_torch_fill_mask_with_duplicate_targets_and_topk(self):
|
def test_small_model_pt(self):
|
||||||
model_name = self.small_models[0]
|
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", top_k=2, framework="pt")
|
||||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
|
|
||||||
|
outputs = unmasker("My name is <mask>")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=6),
|
||||||
|
[
|
||||||
|
{"sequence": "My name is Maul", "score": 2.2e-05, "token": 35676, "token_str": " Maul"},
|
||||||
|
{"sequence": "My name isELS", "score": 2.2e-05, "token": 16416, "token_str": "ELS"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = unmasker("The largest city in France is <mask>")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=6),
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"sequence": "The largest city in France is Maul",
|
||||||
|
"score": 2.2e-05,
|
||||||
|
"token": 35676,
|
||||||
|
"token_str": " Maul",
|
||||||
|
},
|
||||||
|
{"sequence": "The largest city in France isELS", "score": 2.2e-05, "token": 16416, "token_str": "ELS"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = unmasker("My name is <mask>", targets=[" Patrick", " Clara", " Teven"], top_k=3)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs, decimals=6),
|
||||||
|
[
|
||||||
|
{"sequence": "My name is Patrick", "score": 2.1e-05, "token": 3499, "token_str": " Patrick"},
|
||||||
|
{"sequence": "My name is Te", "score": 2e-05, "token": 2941, "token_str": " Te"},
|
||||||
|
{"sequence": "My name is Clara", "score": 2e-05, "token": 13606, "token_str": " Clara"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_large_model_pt(self):
|
||||||
|
unmasker = pipeline(task="fill-mask", model="distilroberta-base", top_k=2, framework="pt")
|
||||||
|
self.run_large_test(unmasker)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_tf
|
||||||
|
def test_large_model_tf(self):
|
||||||
|
unmasker = pipeline(task="fill-mask", model="distilroberta-base", top_k=2, framework="tf")
|
||||||
|
self.run_large_test(unmasker)
|
||||||
|
|
||||||
|
def run_large_test(self, unmasker):
|
||||||
|
outputs = unmasker("My name is <mask>")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
[
|
||||||
|
{"sequence": "My name is John", "score": 0.008, "token": 610, "token_str": " John"},
|
||||||
|
{"sequence": "My name is Chris", "score": 0.007, "token": 1573, "token_str": " Chris"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
outputs = unmasker("The largest city in France is <mask>")
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"sequence": "The largest city in France is Paris",
|
||||||
|
"score": 0.251,
|
||||||
|
"token": 2201,
|
||||||
|
"token_str": " Paris",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"sequence": "The largest city in France is Lyon",
|
||||||
|
"score": 0.214,
|
||||||
|
"token": 12790,
|
||||||
|
"token_str": " Lyon",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = unmasker("My name is <mask>", targets=[" Patrick", " Clara", " Teven"], top_k=3)
|
||||||
|
self.assertEqual(
|
||||||
|
nested_simplify(outputs),
|
||||||
|
[
|
||||||
|
{"sequence": "My name is Patrick", "score": 0.005, "token": 3499, "token_str": " Patrick"},
|
||||||
|
{"sequence": "My name is Clara", "score": 0.000, "token": 13606, "token_str": " Clara"},
|
||||||
|
{"sequence": "My name is Te", "score": 0.000, "token": 2941, "token_str": " Te"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_model_no_pad_pt(self):
|
||||||
|
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", framework="pt")
|
||||||
|
unmasker.tokenizer.pad_token_id = None
|
||||||
|
unmasker.tokenizer.pad_token = None
|
||||||
|
self.run_pipeline_test(unmasker.model, unmasker.tokenizer)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_model_no_pad_tf(self):
|
||||||
|
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", framework="tf")
|
||||||
|
unmasker.tokenizer.pad_token_id = None
|
||||||
|
unmasker.tokenizer.pad_token = None
|
||||||
|
self.run_pipeline_test(unmasker.model, unmasker.tokenizer)
|
||||||
|
|
||||||
|
def run_pipeline_test(self, model, tokenizer):
|
||||||
|
if tokenizer.mask_token_id is None:
|
||||||
|
self.skipTest("The provided tokenizer has no mask token, (probably reformer)")
|
||||||
|
|
||||||
|
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
outputs = fill_masker(f"This is a {tokenizer.mask_token}")
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token}"])
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
[
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
fill_masker([None])
|
||||||
|
# Multiple masks
|
||||||
|
with self.assertRaises(PipelineException):
|
||||||
|
fill_masker(f"This is {tokenizer.mask_token} {tokenizer.mask_token}")
|
||||||
|
# No mask_token is not supported
|
||||||
|
with self.assertRaises(PipelineException):
|
||||||
|
fill_masker("This is")
|
||||||
|
|
||||||
|
self.run_test_top_k(model, tokenizer)
|
||||||
|
self.run_test_targets(model, tokenizer)
|
||||||
|
self.run_test_top_k_targets(model, tokenizer)
|
||||||
|
self.fill_mask_with_duplicate_targets_and_top_k(model, tokenizer)
|
||||||
|
|
||||||
|
def run_test_targets(self, model, tokenizer):
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
targets = list(sorted(vocab.keys()))[:2]
|
||||||
|
# Pipeline argument
|
||||||
|
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer, targets=targets)
|
||||||
|
outputs = fill_masker(f"This is a {tokenizer.mask_token}")
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
target_ids = {vocab[el] for el in targets}
|
||||||
|
self.assertEqual(set(el["token"] for el in outputs), target_ids)
|
||||||
|
self.assertEqual(set(el["token_str"] for el in outputs), set(targets))
|
||||||
|
|
||||||
|
# Call argument
|
||||||
|
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
|
||||||
|
outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=targets)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
target_ids = {vocab[el] for el in targets}
|
||||||
|
self.assertEqual(set(el["token"] for el in outputs), target_ids)
|
||||||
|
self.assertEqual(set(el["token_str"] for el in outputs), set(targets))
|
||||||
|
|
||||||
|
# Score equivalence
|
||||||
|
outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=targets)
|
||||||
|
tokens = [top_mask["token_str"] for top_mask in outputs]
|
||||||
|
scores = [top_mask["score"] for top_mask in outputs]
|
||||||
|
|
||||||
|
unmasked_targets = fill_masker(f"This is a {tokenizer.mask_token}", targets=tokens)
|
||||||
|
target_scores = [top_mask["score"] for top_mask in unmasked_targets]
|
||||||
|
self.assertEqual(nested_simplify(scores), nested_simplify(target_scores))
|
||||||
|
|
||||||
|
# Raises with invalid
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=[""])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=[])
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets="")
|
||||||
|
|
||||||
|
def run_test_top_k(self, model, tokenizer):
|
||||||
|
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer, top_k=2)
|
||||||
|
outputs = fill_masker(f"This is a {tokenizer.mask_token}")
|
||||||
|
self.assertEqual(
|
||||||
|
outputs,
|
||||||
|
[
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
|
||||||
|
outputs2 = fill_masker(f"This is a {tokenizer.mask_token}", top_k=2)
|
||||||
|
self.assertEqual(
|
||||||
|
outputs2,
|
||||||
|
[
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
{"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
self.assertEqual(nested_simplify(outputs), nested_simplify(outputs2))
|
||||||
|
|
||||||
|
def run_test_top_k_targets(self, model, tokenizer):
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
# top_k=2, ntargets=3
|
||||||
|
targets = list(sorted(vocab.keys()))[:3]
|
||||||
|
outputs = fill_masker(f"This is a {tokenizer.mask_token}", top_k=2, targets=targets)
|
||||||
|
|
||||||
|
# If we use the most probably targets, and filter differently, we should still
|
||||||
|
# have the same results
|
||||||
|
targets2 = [el["token_str"] for el in sorted(outputs, key=lambda x: x["score"], reverse=True)]
|
||||||
|
outputs2 = fill_masker(f"This is a {tokenizer.mask_token}", top_k=3, targets=targets2)
|
||||||
|
|
||||||
|
# They should yield exactly the same result
|
||||||
|
self.assertEqual(nested_simplify(outputs), nested_simplify(outputs2))
|
||||||
|
|
||||||
|
def fill_mask_with_duplicate_targets_and_top_k(self, model, tokenizer):
|
||||||
|
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
# String duplicates + id duplicates
|
# String duplicates + id duplicates
|
||||||
targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
|
targets = list(sorted(vocab.keys()))[:3]
|
||||||
top_k = 10
|
targets = [targets[0], targets[1], targets[0], targets[2], targets[1]]
|
||||||
outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
|
outputs = fill_masker(f"My name is {tokenizer.mask_token}", targets=targets, top_k=10)
|
||||||
|
|
||||||
# The target list contains duplicates, so we can't output more
|
# The target list contains duplicates, so we can't output more
|
||||||
# than them
|
# than them
|
||||||
self.assertEqual(len(outputs), 3)
|
self.assertEqual(len(outputs), 3)
|
||||||
|
|
||||||
@require_tf
|
|
||||||
def test_tf_fill_mask_with_targets(self):
|
|
||||||
valid_inputs = ["My name is <mask>"]
|
|
||||||
# ' Teven' will yield a warning but work as " Te"
|
|
||||||
invalid_targets = [[], [""], ""]
|
|
||||||
unmasker = pipeline(
|
|
||||||
task="fill-mask", model=self.small_models[0], tokenizer=self.small_models[0], framework="tf"
|
|
||||||
)
|
|
||||||
outputs = unmasker(valid_inputs, targets=[" Teven", "ĠPatrick", "ĠClara"])
|
|
||||||
self.assertEqual(
|
|
||||||
nested_simplify(outputs),
|
|
||||||
[
|
|
||||||
{"sequence": "My name is Clara", "score": 0.0, "token": 13606, "token_str": " Clara"},
|
|
||||||
{"sequence": "My name is Patrick", "score": 0.0, "token": 3499, "token_str": " Patrick"},
|
|
||||||
{"sequence": "My name is Te", "score": 0.0, "token": 2941, "token_str": " Te"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
# topk
|
|
||||||
outputs = unmasker(valid_inputs, targets=[" Teven", "ĠPatrick", "ĠClara"], top_k=2)
|
|
||||||
self.assertEqual(
|
|
||||||
nested_simplify(outputs),
|
|
||||||
[
|
|
||||||
{"sequence": "My name is Clara", "score": 0.0, "token": 13606, "token_str": " Clara"},
|
|
||||||
{"sequence": "My name is Patrick", "score": 0.0, "token": 3499, "token_str": " Patrick"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
for targets in invalid_targets:
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
unmasker(valid_inputs, targets=targets)
|
|
||||||
|
|
||||||
@require_torch
|
|
||||||
@slow
|
|
||||||
def test_torch_fill_mask_results(self):
|
|
||||||
mandatory_keys = {"sequence", "score", "token"}
|
|
||||||
valid_inputs = [
|
|
||||||
"My name is <mask>",
|
|
||||||
"The largest city in France is <mask>",
|
|
||||||
]
|
|
||||||
valid_targets = ["ĠPatrick", "ĠClara"]
|
|
||||||
for model_name in self.large_models:
|
|
||||||
unmasker = pipeline(
|
|
||||||
task="fill-mask",
|
|
||||||
model=model_name,
|
|
||||||
tokenizer=model_name,
|
|
||||||
framework="pt",
|
|
||||||
top_k=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
mono_result = unmasker(valid_inputs[0], targets=valid_targets)
|
|
||||||
self.assertIsInstance(mono_result, list)
|
|
||||||
self.assertIsInstance(mono_result[0], dict)
|
|
||||||
|
|
||||||
for mandatory_key in mandatory_keys:
|
|
||||||
self.assertIn(mandatory_key, mono_result[0])
|
|
||||||
|
|
||||||
multi_result = [unmasker(valid_input) for valid_input in valid_inputs]
|
|
||||||
self.assertIsInstance(multi_result, list)
|
|
||||||
self.assertIsInstance(multi_result[0], (dict, list))
|
|
||||||
|
|
||||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT):
|
|
||||||
for r, e in zip(result, expected):
|
|
||||||
self.assertEqual(r["sequence"], e["sequence"])
|
|
||||||
self.assertEqual(r["token_str"], e["token_str"])
|
|
||||||
self.assertEqual(r["token"], e["token"])
|
|
||||||
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
|
||||||
|
|
||||||
if isinstance(multi_result[0], list):
|
|
||||||
multi_result = multi_result[0]
|
|
||||||
|
|
||||||
for result in multi_result:
|
|
||||||
for key in mandatory_keys:
|
|
||||||
self.assertIn(key, result)
|
|
||||||
|
|
||||||
self.assertRaises(Exception, unmasker, [None])
|
|
||||||
|
|
||||||
valid_inputs = valid_inputs[:1]
|
|
||||||
mono_result = unmasker(valid_inputs[0], targets=valid_targets)
|
|
||||||
self.assertIsInstance(mono_result, list)
|
|
||||||
self.assertIsInstance(mono_result[0], dict)
|
|
||||||
|
|
||||||
for mandatory_key in mandatory_keys:
|
|
||||||
self.assertIn(mandatory_key, mono_result[0])
|
|
||||||
|
|
||||||
multi_result = [unmasker(valid_input) for valid_input in valid_inputs]
|
|
||||||
self.assertIsInstance(multi_result, list)
|
|
||||||
self.assertIsInstance(multi_result[0], (dict, list))
|
|
||||||
|
|
||||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT):
|
|
||||||
for r, e in zip(result, expected):
|
|
||||||
self.assertEqual(r["sequence"], e["sequence"])
|
|
||||||
self.assertEqual(r["token_str"], e["token_str"])
|
|
||||||
self.assertEqual(r["token"], e["token"])
|
|
||||||
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
|
||||||
|
|
||||||
if isinstance(multi_result[0], list):
|
|
||||||
multi_result = multi_result[0]
|
|
||||||
|
|
||||||
for result in multi_result:
|
|
||||||
for key in mandatory_keys:
|
|
||||||
self.assertIn(key, result)
|
|
||||||
|
|
||||||
self.assertRaises(Exception, unmasker, [None])
|
|
||||||
|
|
||||||
@require_tf
|
|
||||||
@slow
|
|
||||||
def test_tf_fill_mask_results(self):
|
|
||||||
mandatory_keys = {"sequence", "score", "token"}
|
|
||||||
valid_inputs = [
|
|
||||||
"My name is <mask>",
|
|
||||||
"The largest city in France is <mask>",
|
|
||||||
]
|
|
||||||
valid_targets = ["ĠPatrick", "ĠClara"]
|
|
||||||
for model_name in self.large_models:
|
|
||||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)
|
|
||||||
|
|
||||||
mono_result = unmasker(valid_inputs[0], targets=valid_targets)
|
|
||||||
self.assertIsInstance(mono_result, list)
|
|
||||||
self.assertIsInstance(mono_result[0], dict)
|
|
||||||
|
|
||||||
for mandatory_key in mandatory_keys:
|
|
||||||
self.assertIn(mandatory_key, mono_result[0])
|
|
||||||
|
|
||||||
multi_result = [unmasker(valid_input) for valid_input in valid_inputs]
|
|
||||||
self.assertIsInstance(multi_result, list)
|
|
||||||
self.assertIsInstance(multi_result[0], (dict, list))
|
|
||||||
|
|
||||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_RESULT):
|
|
||||||
for r, e in zip(result, expected):
|
|
||||||
self.assertEqual(r["sequence"], e["sequence"])
|
|
||||||
self.assertEqual(r["token_str"], e["token_str"])
|
|
||||||
self.assertEqual(r["token"], e["token"])
|
|
||||||
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
|
||||||
|
|
||||||
if isinstance(multi_result[0], list):
|
|
||||||
multi_result = multi_result[0]
|
|
||||||
|
|
||||||
for result in multi_result:
|
|
||||||
for key in mandatory_keys:
|
|
||||||
self.assertIn(key, result)
|
|
||||||
|
|
||||||
self.assertRaises(Exception, unmasker, [None])
|
|
||||||
|
|
||||||
valid_inputs = valid_inputs[:1]
|
|
||||||
mono_result = unmasker(valid_inputs[0], targets=valid_targets)
|
|
||||||
self.assertIsInstance(mono_result, list)
|
|
||||||
self.assertIsInstance(mono_result[0], dict)
|
|
||||||
|
|
||||||
for mandatory_key in mandatory_keys:
|
|
||||||
self.assertIn(mandatory_key, mono_result[0])
|
|
||||||
|
|
||||||
multi_result = [unmasker(valid_input) for valid_input in valid_inputs]
|
|
||||||
self.assertIsInstance(multi_result, list)
|
|
||||||
self.assertIsInstance(multi_result[0], (dict, list))
|
|
||||||
|
|
||||||
for result, expected in zip(multi_result, EXPECTED_FILL_MASK_TARGET_RESULT):
|
|
||||||
for r, e in zip(result, expected):
|
|
||||||
self.assertEqual(r["sequence"], e["sequence"])
|
|
||||||
self.assertEqual(r["token_str"], e["token_str"])
|
|
||||||
self.assertEqual(r["token"], e["token"])
|
|
||||||
self.assertAlmostEqual(r["score"], e["score"], places=3)
|
|
||||||
|
|
||||||
if isinstance(multi_result[0], list):
|
|
||||||
multi_result = multi_result[0]
|
|
||||||
|
|
||||||
for result in multi_result:
|
|
||||||
for key in mandatory_keys:
|
|
||||||
self.assertIn(key, result)
|
|
||||||
|
|
||||||
self.assertRaises(Exception, unmasker, [None])
|
|
||||||
|
|
||||||
@require_tf
|
|
||||||
@slow
|
|
||||||
def test_tf_fill_mask_targets_equivalence(self):
|
|
||||||
model_name = self.large_models[0]
|
|
||||||
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf")
|
|
||||||
unmasked = unmasker(self.valid_inputs[0])
|
|
||||||
tokens = [top_mask["token_str"] for top_mask in unmasked]
|
|
||||||
scores = [top_mask["score"] for top_mask in unmasked]
|
|
||||||
|
|
||||||
unmasked_targets = unmasker(self.valid_inputs[0], targets=tokens)
|
|
||||||
target_scores = [top_mask["score"] for top_mask in unmasked_targets]
|
|
||||||
|
|
||||||
self.assertEqual(scores, target_scores)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user