FillMaskPipeline: support passing top_k on __call__ (#7971)

* FillMaskPipeline: support passing top_k on __call__

Also rename the `topk` argument to `top_k`; the old `topk` keyword is still accepted but emits a deprecation `FutureWarning`.

* Migrate the tests to the new parameter name

* Address review feedback from @sgugger
This commit is contained in:
Julien Chaumond
2020-10-22 18:54:25 +02:00
committed by GitHub
parent 2e5052d4f1
commit ff65beafa3
2 changed files with 20 additions and 10 deletions

View File

@@ -1183,7 +1183,7 @@ class ZeroShotClassificationPipeline(Pipeline):
@add_end_docstrings( @add_end_docstrings(
PIPELINE_INIT_ARGS, PIPELINE_INIT_ARGS,
r""" r"""
topk (:obj:`int`, defaults to 5): The number of predictions to return. top_k (:obj:`int`, defaults to 5): The number of predictions to return.
""", """,
) )
class FillMaskPipeline(Pipeline): class FillMaskPipeline(Pipeline):
@@ -1212,8 +1212,9 @@ class FillMaskPipeline(Pipeline):
framework: Optional[str] = None, framework: Optional[str] = None,
args_parser: ArgumentHandler = None, args_parser: ArgumentHandler = None,
device: int = -1, device: int = -1,
topk=5, top_k=5,
task: str = "", task: str = "",
**kwargs
): ):
super().__init__( super().__init__(
model=model, model=model,
@@ -1228,7 +1229,14 @@ class FillMaskPipeline(Pipeline):
self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING) self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
self.topk = topk if "topk" in kwargs:
warnings.warn(
"The `topk` argument is deprecated and will be removed in a future version, use `top_k` instead.",
FutureWarning,
)
self.top_k = kwargs.pop("topk")
else:
self.top_k = top_k
def ensure_exactly_one_mask_token(self, masked_index: np.ndarray): def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
numel = np.prod(masked_index.shape) numel = np.prod(masked_index.shape)
@@ -1245,7 +1253,7 @@ class FillMaskPipeline(Pipeline):
f"No mask_token ({self.tokenizer.mask_token}) found on the input", f"No mask_token ({self.tokenizer.mask_token}) found on the input",
) )
def __call__(self, *args, targets=None, **kwargs): def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
""" """
Fill the masked token in the text(s) given as inputs. Fill the masked token in the text(s) given as inputs.
@@ -1256,6 +1264,8 @@ class FillMaskPipeline(Pipeline):
When passed, the model will return the scores for the passed token or tokens rather than the top k When passed, the model will return the scores for the passed token or tokens rather than the top k
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
be tokenized and the first resulting token will be used (with a warning). be tokenized and the first resulting token will be used (with a warning).
top_k (:obj:`int`, `optional`):
When passed, overrides the number of predictions to return.
Return: Return:
A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
@@ -1303,7 +1313,7 @@ class FillMaskPipeline(Pipeline):
logits = outputs[i, masked_index.item(), :] logits = outputs[i, masked_index.item(), :]
probs = tf.nn.softmax(logits) probs = tf.nn.softmax(logits)
if targets is None: if targets is None:
topk = tf.math.top_k(probs, k=self.topk) topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy() values, predictions = topk.values.numpy(), topk.indices.numpy()
else: else:
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1))) values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
@@ -1319,7 +1329,7 @@ class FillMaskPipeline(Pipeline):
logits = outputs[i, masked_index.item(), :] logits = outputs[i, masked_index.item(), :]
probs = logits.softmax(dim=0) probs = logits.softmax(dim=0)
if targets is None: if targets is None:
values, predictions = probs.topk(self.topk) values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
else: else:
values = probs[..., target_inds] values = probs[..., target_inds]
sort_inds = list(reversed(values.argsort(dim=-1))) sort_inds = list(reversed(values.argsort(dim=-1)))

View File

@@ -226,7 +226,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
model=model_name, model=model_name,
tokenizer=model_name, tokenizer=model_name,
framework="pt", framework="pt",
topk=2, top_k=2,
) )
self._test_mono_column_pipeline( self._test_mono_column_pipeline(
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"] nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
@@ -249,7 +249,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
model=model_name, model=model_name,
tokenizer=model_name, tokenizer=model_name,
framework="tf", framework="tf",
topk=2, top_k=2,
) )
self._test_mono_column_pipeline( self._test_mono_column_pipeline(
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"] nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
@@ -298,7 +298,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
model=model_name, model=model_name,
tokenizer=model_name, tokenizer=model_name,
framework="pt", framework="pt",
topk=2, top_k=2,
) )
self._test_mono_column_pipeline( self._test_mono_column_pipeline(
nlp, nlp,
@@ -326,7 +326,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
] ]
valid_targets = [" Patrick", " Clara"] valid_targets = [" Patrick", " Clara"]
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS: for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2) nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)
self._test_mono_column_pipeline( self._test_mono_column_pipeline(
nlp, nlp,
valid_inputs, valid_inputs,