From ff65beafa35959d6a1ee764835064af75b67c8c5 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 22 Oct 2020 18:54:25 +0200 Subject: [PATCH] FillMaskPipeline: support passing top_k on __call__ (#7971) * FillMaskPipeline: support passing top_k on __call__ Also move from topk to top_k * migrate to new param name in tests * Review from @sgugger --- src/transformers/pipelines.py | 22 ++++++++++++++++------ tests/test_pipelines.py | 8 ++++---- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index d80300492e..d76a5f9f1c 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -1183,7 +1183,7 @@ class ZeroShotClassificationPipeline(Pipeline): @add_end_docstrings( PIPELINE_INIT_ARGS, r""" - topk (:obj:`int`, defaults to 5): The number of predictions to return. + top_k (:obj:`int`, defaults to 5): The number of predictions to return. """, ) class FillMaskPipeline(Pipeline): @@ -1212,8 +1212,9 @@ class FillMaskPipeline(Pipeline): framework: Optional[str] = None, args_parser: ArgumentHandler = None, device: int = -1, - topk=5, + top_k=5, task: str = "", + **kwargs ): super().__init__( model=model, @@ -1228,7 +1229,14 @@ class FillMaskPipeline(Pipeline): self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING) - self.topk = topk + if "topk" in kwargs: + warnings.warn( + "The `topk` argument is deprecated and will be removed in a future version, use `top_k` instead.", + FutureWarning, + ) + self.top_k = kwargs.pop("topk") + else: + self.top_k = top_k def ensure_exactly_one_mask_token(self, masked_index: np.ndarray): numel = np.prod(masked_index.shape) @@ -1245,7 +1253,7 @@ class FillMaskPipeline(Pipeline): f"No mask_token ({self.tokenizer.mask_token}) found on the input", ) - def __call__(self, *args, targets=None, **kwargs): + def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs): """ Fill the masked token in the text(s) given as inputs. @@ -1256,6 +1264,8 @@ class FillMaskPipeline(Pipeline): When passed, the model will return the scores for the passed token or tokens rather than the top k predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first resulting token will be used (with a warning). + top_k (:obj:`int`, `optional`): + When passed, overrides the number of predictions to return. Return: A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the @@ -1303,7 +1313,7 @@ class FillMaskPipeline(Pipeline): logits = outputs[i, masked_index.item(), :] probs = tf.nn.softmax(logits) if targets is None: - topk = tf.math.top_k(probs, k=self.topk) + topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k) values, predictions = topk.values.numpy(), topk.indices.numpy() else: values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1))) @@ -1319,7 +1329,7 @@ class FillMaskPipeline(Pipeline): logits = outputs[i, masked_index.item(), :] probs = logits.softmax(dim=0) if targets is None: - values, predictions = probs.topk(self.topk) + values, predictions = probs.topk(top_k if top_k is not None else self.top_k) else: values = probs[..., target_inds] sort_inds = list(reversed(values.argsort(dim=-1))) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index a88637c694..b6165ec8f6 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -226,7 +226,7 @@ class MonoColumnInputTestCase(unittest.TestCase): model=model_name, tokenizer=model_name, framework="pt", - topk=2, + top_k=2, ) self._test_mono_column_pipeline( nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"] @@ -249,7 +249,7 @@ class MonoColumnInputTestCase(unittest.TestCase): model=model_name, tokenizer=model_name, framework="tf", - topk=2, + top_k=2, ) self._test_mono_column_pipeline( nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"] @@ -298,7 +298,7 @@ class MonoColumnInputTestCase(unittest.TestCase): model=model_name, tokenizer=model_name, framework="pt", - topk=2, + top_k=2, ) self._test_mono_column_pipeline( nlp, @@ -326,7 +326,7 @@ class MonoColumnInputTestCase(unittest.TestCase): ] valid_targets = [" Patrick", " Clara"] for model_name in LARGE_FILL_MASK_FINETUNED_MODELS: - nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2) + nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2) self._test_mono_column_pipeline( nlp, valid_inputs,