FillMaskPipeline: support passing top_k on __call__ (#7971)
* FillMaskPipeline: support passing top_k on __call__

  Also move from topk to top_k

* migrate to new param name in tests

* Review from @sgugger
This commit is contained in:
@@ -1183,7 +1183,7 @@ class ZeroShotClassificationPipeline(Pipeline):
|
|||||||
@add_end_docstrings(
|
@add_end_docstrings(
|
||||||
PIPELINE_INIT_ARGS,
|
PIPELINE_INIT_ARGS,
|
||||||
r"""
|
r"""
|
||||||
topk (:obj:`int`, defaults to 5): The number of predictions to return.
|
top_k (:obj:`int`, defaults to 5): The number of predictions to return.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
class FillMaskPipeline(Pipeline):
|
class FillMaskPipeline(Pipeline):
|
||||||
@@ -1212,8 +1212,9 @@ class FillMaskPipeline(Pipeline):
|
|||||||
framework: Optional[str] = None,
|
framework: Optional[str] = None,
|
||||||
args_parser: ArgumentHandler = None,
|
args_parser: ArgumentHandler = None,
|
||||||
device: int = -1,
|
device: int = -1,
|
||||||
topk=5,
|
top_k=5,
|
||||||
task: str = "",
|
task: str = "",
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -1228,7 +1229,14 @@ class FillMaskPipeline(Pipeline):
|
|||||||
|
|
||||||
self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
|
self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
|
||||||
|
|
||||||
self.topk = topk
|
if "topk" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"The `topk` argument is deprecated and will be removed in a future version, use `top_k` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.top_k = kwargs.pop("topk")
|
||||||
|
else:
|
||||||
|
self.top_k = top_k
|
||||||
|
|
||||||
def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
|
def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
|
||||||
numel = np.prod(masked_index.shape)
|
numel = np.prod(masked_index.shape)
|
||||||
@@ -1245,7 +1253,7 @@ class FillMaskPipeline(Pipeline):
|
|||||||
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
|
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, *args, targets=None, **kwargs):
|
def __call__(self, *args, targets=None, top_k: Optional[int] = None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Fill the masked token in the text(s) given as inputs.
|
Fill the masked token in the text(s) given as inputs.
|
||||||
|
|
||||||
@@ -1256,6 +1264,8 @@ class FillMaskPipeline(Pipeline):
|
|||||||
When passed, the model will return the scores for the passed token or tokens rather than the top k
|
When passed, the model will return the scores for the passed token or tokens rather than the top k
|
||||||
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
|
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
|
||||||
be tokenized and the first resulting token will be used (with a warning).
|
be tokenized and the first resulting token will be used (with a warning).
|
||||||
|
top_k (:obj:`int`, `optional`):
|
||||||
|
When passed, overrides the number of predictions to return.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
|
A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
|
||||||
@@ -1303,7 +1313,7 @@ class FillMaskPipeline(Pipeline):
|
|||||||
logits = outputs[i, masked_index.item(), :]
|
logits = outputs[i, masked_index.item(), :]
|
||||||
probs = tf.nn.softmax(logits)
|
probs = tf.nn.softmax(logits)
|
||||||
if targets is None:
|
if targets is None:
|
||||||
topk = tf.math.top_k(probs, k=self.topk)
|
topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
|
||||||
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
||||||
else:
|
else:
|
||||||
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
|
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
|
||||||
@@ -1319,7 +1329,7 @@ class FillMaskPipeline(Pipeline):
|
|||||||
logits = outputs[i, masked_index.item(), :]
|
logits = outputs[i, masked_index.item(), :]
|
||||||
probs = logits.softmax(dim=0)
|
probs = logits.softmax(dim=0)
|
||||||
if targets is None:
|
if targets is None:
|
||||||
values, predictions = probs.topk(self.topk)
|
values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
|
||||||
else:
|
else:
|
||||||
values = probs[..., target_inds]
|
values = probs[..., target_inds]
|
||||||
sort_inds = list(reversed(values.argsort(dim=-1)))
|
sort_inds = list(reversed(values.argsort(dim=-1)))
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
tokenizer=model_name,
|
tokenizer=model_name,
|
||||||
framework="pt",
|
framework="pt",
|
||||||
topk=2,
|
top_k=2,
|
||||||
)
|
)
|
||||||
self._test_mono_column_pipeline(
|
self._test_mono_column_pipeline(
|
||||||
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
||||||
@@ -249,7 +249,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
tokenizer=model_name,
|
tokenizer=model_name,
|
||||||
framework="tf",
|
framework="tf",
|
||||||
topk=2,
|
top_k=2,
|
||||||
)
|
)
|
||||||
self._test_mono_column_pipeline(
|
self._test_mono_column_pipeline(
|
||||||
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
|
||||||
@@ -298,7 +298,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
tokenizer=model_name,
|
tokenizer=model_name,
|
||||||
framework="pt",
|
framework="pt",
|
||||||
topk=2,
|
top_k=2,
|
||||||
)
|
)
|
||||||
self._test_mono_column_pipeline(
|
self._test_mono_column_pipeline(
|
||||||
nlp,
|
nlp,
|
||||||
@@ -326,7 +326,7 @@ class MonoColumnInputTestCase(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
valid_targets = [" Patrick", " Clara"]
|
valid_targets = [" Patrick", " Clara"]
|
||||||
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
|
for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
|
||||||
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2)
|
nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)
|
||||||
self._test_mono_column_pipeline(
|
self._test_mono_column_pipeline(
|
||||||
nlp,
|
nlp,
|
||||||
valid_inputs,
|
valid_inputs,
|
||||||
|
|||||||
Reference in New Issue
Block a user