From bc820476a5c72060f810f825298befd5ec85da4d Mon Sep 17 00:00:00 2001 From: Joe Davison Date: Wed, 12 Aug 2020 12:48:29 -0400 Subject: [PATCH] add targets arg to fill-mask pipeline (#6239) * add targets arg to fill-mask pipeline * add tests and more error handling * quality * update docstring --- src/transformers/pipelines.py | 45 +++++++++++++++++++++--- tests/test_pipelines.py | 65 ++++++++++++++++++++++++++++++++++- 2 files changed, 104 insertions(+), 6 deletions(-) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index cbf79bdc9c..e2186c1e0d 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -1158,12 +1158,17 @@ class FillMaskPipeline(Pipeline): f"No mask_token ({self.tokenizer.mask_token}) found on the input", ) - def __call__(self, *args, **kwargs): + def __call__(self, *args, targets=None, **kwargs): """ Fill the masked token in the text(s) given as inputs. Args: - args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of prompts) with masked tokens. + args (:obj:`str` or :obj:`List[str]`): + One or several texts (or one list of prompts) with masked tokens. + targets (:obj:`str` or :obj:`List[str]`, `optional`): + When passed, the model will return the scores for the passed token or tokens rather than the top k + predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will + be tokenized and the first resulting token will be used (with a warning). Return: A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the @@ -1180,6 +1185,24 @@ class FillMaskPipeline(Pipeline): results = [] batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0) + if targets is not None: + if len(targets) == 0 or len(targets[0]) == 0: + raise ValueError("At least one target must be provided when passed.") + if isinstance(targets, str): + targets = [targets] + + targets_proc = [] + for target in targets: + target_enc = self.tokenizer.tokenize(target) + if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token: + logger.warning( + "The specified target token `{}` does not exist in the model vocabulary. Replacing with `{}`.".format( + target, target_enc[0] + ) + ) + targets_proc.append(target_enc[0]) + target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc)) + for i in range(batch_size): input_ids = inputs["input_ids"][i] result = [] @@ -1192,8 +1215,14 @@ class FillMaskPipeline(Pipeline): logits = outputs[i, masked_index.item(), :] probs = tf.nn.softmax(logits) - topk = tf.math.top_k(probs, k=self.topk) - values, predictions = topk.values.numpy(), topk.indices.numpy() + if targets is None: + topk = tf.math.top_k(probs, k=self.topk) + values, predictions = topk.values.numpy(), topk.indices.numpy() + else: + values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1))) + sort_inds = tf.reverse(tf.argsort(values), [0]) + values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy() + predictions = target_inds[sort_inds.numpy()] else: masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero() @@ -1202,7 +1231,13 @@ class FillMaskPipeline(Pipeline): logits = outputs[i, masked_index.item(), :] probs = logits.softmax(dim=0) - values, predictions = probs.topk(self.topk) + if targets is None: + values, predictions = probs.topk(self.topk) + else: + values = probs[..., target_inds] + sort_inds = list(reversed(values.argsort(dim=-1))) + values = values[..., sort_inds] + predictions = target_inds[sort_inds] for v, p in zip(values.tolist(), predictions.tolist()): tokens = input_ids.numpy() diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index cd11fcfb1c..205db94b4c 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -41,6 +41,23 @@ expected_fill_mask_result = [ ], ] +expected_fill_mask_target_result = [ + [ + { + "sequence": "My name is Patrick", + "score": 0.004992353264242411, + "token": 3499, + "token_str": "ĠPatrick", + }, + { + "sequence": "My name is Clara", + "score": 0.00019297805556561798, + "token": 13606, + "token_str": "ĠClara", + }, + ] +] + SUMMARIZATION_KWARGS = dict(num_beams=2, min_length=2, max_length=5) @@ -139,7 +156,7 @@ class MonoColumnInputTestCase(unittest.TestCase): for key in output_keys: self.assertIn(key, mono_result[0]) - multi_result = [nlp(input) for input in valid_inputs] + multi_result = [nlp(input, **kwargs) for input in valid_inputs] self.assertIsInstance(multi_result, list) self.assertIsInstance(multi_result[0], (dict, list)) @@ -219,6 +236,34 @@ class MonoColumnInputTestCase(unittest.TestCase): nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"] ) + @require_torch + def test_torch_fill_mask_with_targets(self): + valid_inputs = ["My name is "] + valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]] + invalid_targets = [[], [""], ""] + for model_name in FILL_MASK_FINETUNED_MODELS: + nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt") + for targets in valid_targets: + outputs = nlp(valid_inputs, targets=targets) + self.assertIsInstance(outputs, list) + self.assertEqual(len(outputs), len(targets)) + for targets in invalid_targets: + self.assertRaises(ValueError, nlp, valid_inputs, targets=targets) + + @require_tf + def test_tf_fill_mask_with_targets(self): + valid_inputs = ["My name is "] + valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]] + invalid_targets = [[], [""], ""] + for model_name in FILL_MASK_FINETUNED_MODELS: + nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf") + for targets in valid_targets: + outputs = nlp(valid_inputs, targets=targets) + self.assertIsInstance(outputs, list) + self.assertEqual(len(outputs), len(targets)) + for targets in invalid_targets: + self.assertRaises(ValueError, nlp, valid_inputs, targets=targets) + @require_torch @slow def test_torch_fill_mask_results(self): @@ -227,6 +272,7 @@ class MonoColumnInputTestCase(unittest.TestCase): "My name is ", "The largest city in France is ", ] + valid_targets = [" Patrick", " Clara"] for model_name in LARGE_FILL_MASK_FINETUNED_MODELS: nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt", topk=2,) self._test_mono_column_pipeline( @@ -236,6 +282,14 @@ class MonoColumnInputTestCase(unittest.TestCase): expected_multi_result=expected_fill_mask_result, expected_check_keys=["sequence"], ) + self._test_mono_column_pipeline( + nlp, + valid_inputs[:1], + mandatory_keys, + expected_multi_result=expected_fill_mask_target_result, + expected_check_keys=["sequence"], + targets=valid_targets, + ) @require_tf @slow @@ -245,6 +299,7 @@ class MonoColumnInputTestCase(unittest.TestCase): "My name is ", "The largest city in France is ", ] + valid_targets = [" Patrick", " Clara"] for model_name in LARGE_FILL_MASK_FINETUNED_MODELS: nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2) self._test_mono_column_pipeline( @@ -254,6 +309,14 @@ class MonoColumnInputTestCase(unittest.TestCase): expected_multi_result=expected_fill_mask_result, expected_check_keys=["sequence"], ) + self._test_mono_column_pipeline( + nlp, + valid_inputs[:1], + mandatory_keys, + expected_multi_result=expected_fill_mask_target_result, + expected_check_keys=["sequence"], + targets=valid_targets, + ) @require_torch def test_torch_summarization(self):