add targets arg to fill-mask pipeline (#6239)
* add targets arg to fill-mask pipeline
* add tests and more error handling
* quality
* update docstring
This commit is contained in:
@@ -1158,12 +1158,17 @@ class FillMaskPipeline(Pipeline):
|
|||||||
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
|
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, targets=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Fill the masked token in the text(s) given as inputs.
|
Fill the masked token in the text(s) given as inputs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of prompts) with masked tokens.
|
args (:obj:`str` or :obj:`List[str]`):
|
||||||
|
One or several texts (or one list of prompts) with masked tokens.
|
||||||
|
targets (:obj:`str` or :obj:`List[str]`, `optional`):
|
||||||
|
When passed, the model will return the scores for the passed token or tokens rather than the top k
|
||||||
|
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
|
||||||
|
be tokenized and the first resulting token will be used (with a warning).
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
|
A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
|
||||||
@@ -1180,6 +1185,24 @@ class FillMaskPipeline(Pipeline):
|
|||||||
results = []
|
results = []
|
||||||
batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)
|
batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
if len(targets) == 0 or len(targets[0]) == 0:
|
||||||
|
raise ValueError("At least one target must be provided when passed.")
|
||||||
|
if isinstance(targets, str):
|
||||||
|
targets = [targets]
|
||||||
|
|
||||||
|
targets_proc = []
|
||||||
|
for target in targets:
|
||||||
|
target_enc = self.tokenizer.tokenize(target)
|
||||||
|
if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
|
||||||
|
logger.warning(
|
||||||
|
"The specified target token `{}` does not exist in the model vocabulary. Replacing with `{}`.".format(
|
||||||
|
target, target_enc[0]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
targets_proc.append(target_enc[0])
|
||||||
|
target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
input_ids = inputs["input_ids"][i]
|
input_ids = inputs["input_ids"][i]
|
||||||
result = []
|
result = []
|
||||||
@@ -1192,8 +1215,14 @@ class FillMaskPipeline(Pipeline):
|
|||||||
|
|
||||||
logits = outputs[i, masked_index.item(), :]
|
logits = outputs[i, masked_index.item(), :]
|
||||||
probs = tf.nn.softmax(logits)
|
probs = tf.nn.softmax(logits)
|
||||||
topk = tf.math.top_k(probs, k=self.topk)
|
if targets is None:
|
||||||
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
topk = tf.math.top_k(probs, k=self.topk)
|
||||||
|
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
||||||
|
else:
|
||||||
|
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
|
||||||
|
sort_inds = tf.reverse(tf.argsort(values), [0])
|
||||||
|
values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
|
||||||
|
predictions = target_inds[sort_inds.numpy()]
|
||||||
else:
|
else:
|
||||||
masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero()
|
masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero()
|
||||||
|
|
||||||
@@ -1202,7 +1231,13 @@ class FillMaskPipeline(Pipeline):
|
|||||||
|
|
||||||
logits = outputs[i, masked_index.item(), :]
|
logits = outputs[i, masked_index.item(), :]
|
||||||
probs = logits.softmax(dim=0)
|
probs = logits.softmax(dim=0)
|
||||||
values, predictions = probs.topk(self.topk)
|
if targets is None:
|
||||||
|
values, predictions = probs.topk(self.topk)
|
||||||
|
else:
|
||||||
|
values = probs[..., target_inds]
|
||||||
|
sort_inds = list(reversed(values.argsort(dim=-1)))
|
||||||
|
values = values[..., sort_inds]
|
||||||
|
predictions = target_inds[sort_inds]
|
||||||
|
|
||||||
for v, p in zip(values.tolist(), predictions.tolist()):
|
for v, p in zip(values.tolist(), predictions.tolist()):
|
||||||
tokens = input_ids.numpy()
|
tokens = input_ids.numpy()
|
||||||
|
|||||||
@@ -41,6 +41,23 @@ expected_fill_mask_result = [
|
|||||||
],
|
],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
expected_fill_mask_target_result = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"sequence": "<s>My name is Patrick</s>",
|
||||||
|
"score": 0.004992353264242411,
|
||||||
|
"token": 3499,
|
||||||
|
"token_str": "ĠPatrick",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"sequence": "<s>My name is Clara</s>",
|
||||||
|
"score": 0.00019297805556561798,
|
||||||
|
"token": 13606,
|
||||||
|
"token_str": "ĠClara",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
SUMMARIZATION_KWARGS = dict(num_beams=2, min_length=2, max_length=5)
|
SUMMARIZATION_KWARGS = dict(num_beams=2, min_length=2, max_length=5)
|
||||||
|
|
||||||
|
|
||||||
@@ -139,7 +156,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
for key in output_keys:
|
for key in output_keys:
|
||||||
self.assertIn(key, mono_result[0])
|
self.assertIn(key, mono_result[0])
|
||||||
|
|
||||||
multi_result = [nlp(input) for input in valid_inputs]
|
multi_result = [nlp(input, **kwargs) for input in valid_inputs]
|
||||||
self.assertIsInstance(multi_result, list)
|
self.assertIsInstance(multi_result, list)
|
||||||
self.assertIsInstance(multi_result[0], (dict, list))
|
self.assertIsInstance(multi_result[0], (dict, list))
|
||||||
|
|
||||||
@@ -219,6 +236,34 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_torch_fill_mask_with_targets(self):
|
||||||
|
valid_inputs = ["My name is <mask>"]
|
||||||
|
valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
|
||||||
|
invalid_targets = [[], [""], ""]
|
||||||
|
for model_name in FILL_MASK_FINETUNED_MODELS:
|
||||||
|
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
|
||||||
|
for targets in valid_targets:
|
||||||
|
outputs = nlp(valid_inputs, targets=targets)
|
||||||
|
self.assertIsInstance(outputs, list)
|
||||||
|
self.assertEqual(len(outputs), len(targets))
|
||||||
|
for targets in invalid_targets:
|
||||||
|
self.assertRaises(ValueError, nlp, valid_inputs, targets=targets)
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
def test_tf_fill_mask_with_targets(self):
|
||||||
|
valid_inputs = ["My name is <mask>"]
|
||||||
|
valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
|
||||||
|
invalid_targets = [[], [""], ""]
|
||||||
|
for model_name in FILL_MASK_FINETUNED_MODELS:
|
||||||
|
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf")
|
||||||
|
for targets in valid_targets:
|
||||||
|
outputs = nlp(valid_inputs, targets=targets)
|
||||||
|
self.assertIsInstance(outputs, list)
|
||||||
|
self.assertEqual(len(outputs), len(targets))
|
||||||
|
for targets in invalid_targets:
|
||||||
|
self.assertRaises(ValueError, nlp, valid_inputs, targets=targets)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_torch_fill_mask_results(self):
|
def test_torch_fill_mask_results(self):
|
||||||
@@ -227,6 +272,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
"My name is <mask>",
|
"My name is <mask>",
|
||||||
"The largest city in France is <mask>",
|
"The largest city in France is <mask>",
|
||||||
]
|
]
|
||||||
|
valid_targets = [" Patrick", " Clara"]
|
||||||
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
|
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
|
||||||
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt", topk=2,)
|
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt", topk=2,)
|
||||||
self._test_mono_column_pipeline(
|
self._test_mono_column_pipeline(
|
||||||
@@ -236,6 +282,14 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
expected_multi_result=expected_fill_mask_result,
|
expected_multi_result=expected_fill_mask_result,
|
||||||
expected_check_keys=["sequence"],
|
expected_check_keys=["sequence"],
|
||||||
)
|
)
|
||||||
|
self._test_mono_column_pipeline(
|
||||||
|
nlp,
|
||||||
|
valid_inputs[:1],
|
||||||
|
mandatory_keys,
|
||||||
|
expected_multi_result=expected_fill_mask_target_result,
|
||||||
|
expected_check_keys=["sequence"],
|
||||||
|
targets=valid_targets,
|
||||||
|
)
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@slow
|
@slow
|
||||||
@@ -245,6 +299,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
"My name is <mask>",
|
"My name is <mask>",
|
||||||
"The largest city in France is <mask>",
|
"The largest city in France is <mask>",
|
||||||
]
|
]
|
||||||
|
valid_targets = [" Patrick", " Clara"]
|
||||||
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
|
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
|
||||||
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2)
|
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2)
|
||||||
self._test_mono_column_pipeline(
|
self._test_mono_column_pipeline(
|
||||||
@@ -254,6 +309,14 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
expected_multi_result=expected_fill_mask_result,
|
expected_multi_result=expected_fill_mask_result,
|
||||||
expected_check_keys=["sequence"],
|
expected_check_keys=["sequence"],
|
||||||
)
|
)
|
||||||
|
self._test_mono_column_pipeline(
|
||||||
|
nlp,
|
||||||
|
valid_inputs[:1],
|
||||||
|
mandatory_keys,
|
||||||
|
expected_multi_result=expected_fill_mask_target_result,
|
||||||
|
expected_check_keys=["sequence"],
|
||||||
|
targets=valid_targets,
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_torch_summarization(self):
|
def test_torch_summarization(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user