add targets arg to fill-mask pipeline (#6239)
* add targets arg to fill-mask pipeline * add tests and more error handling * quality * update docstring
This commit is contained in:
@@ -1158,12 +1158,17 @@ class FillMaskPipeline(Pipeline):
|
||||
f"No mask_token ({self.tokenizer.mask_token}) found on the input",
|
||||
)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, *args, targets=None, **kwargs):
|
||||
"""
|
||||
Fill the masked token in the text(s) given as inputs.
|
||||
|
||||
Args:
|
||||
args (:obj:`str` or :obj:`List[str]`): One or several texts (or one list of prompts) with masked tokens.
|
||||
args (:obj:`str` or :obj:`List[str]`):
|
||||
One or several texts (or one list of prompts) with masked tokens.
|
||||
targets (:obj:`str` or :obj:`List[str]`, `optional`):
|
||||
When passed, the model will return the scores for the passed token or tokens rather than the top k
|
||||
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will
|
||||
be tokenized and the first resulting token will be used (with a warning).
|
||||
|
||||
Return:
|
||||
A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the
|
||||
@@ -1180,6 +1185,24 @@ class FillMaskPipeline(Pipeline):
|
||||
results = []
|
||||
batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)
|
||||
|
||||
if targets is not None:
|
||||
if len(targets) == 0 or len(targets[0]) == 0:
|
||||
raise ValueError("At least one target must be provided when passed.")
|
||||
if isinstance(targets, str):
|
||||
targets = [targets]
|
||||
|
||||
targets_proc = []
|
||||
for target in targets:
|
||||
target_enc = self.tokenizer.tokenize(target)
|
||||
if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
|
||||
logger.warning(
|
||||
"The specified target token `{}` does not exist in the model vocabulary. Replacing with `{}`.".format(
|
||||
target, target_enc[0]
|
||||
)
|
||||
)
|
||||
targets_proc.append(target_enc[0])
|
||||
target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))
|
||||
|
||||
for i in range(batch_size):
|
||||
input_ids = inputs["input_ids"][i]
|
||||
result = []
|
||||
@@ -1192,8 +1215,14 @@ class FillMaskPipeline(Pipeline):
|
||||
|
||||
logits = outputs[i, masked_index.item(), :]
|
||||
probs = tf.nn.softmax(logits)
|
||||
topk = tf.math.top_k(probs, k=self.topk)
|
||||
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
||||
if targets is None:
|
||||
topk = tf.math.top_k(probs, k=self.topk)
|
||||
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
||||
else:
|
||||
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
|
||||
sort_inds = tf.reverse(tf.argsort(values), [0])
|
||||
values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
|
||||
predictions = target_inds[sort_inds.numpy()]
|
||||
else:
|
||||
masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero()
|
||||
|
||||
@@ -1202,7 +1231,13 @@ class FillMaskPipeline(Pipeline):
|
||||
|
||||
logits = outputs[i, masked_index.item(), :]
|
||||
probs = logits.softmax(dim=0)
|
||||
values, predictions = probs.topk(self.topk)
|
||||
if targets is None:
|
||||
values, predictions = probs.topk(self.topk)
|
||||
else:
|
||||
values = probs[..., target_inds]
|
||||
sort_inds = list(reversed(values.argsort(dim=-1)))
|
||||
values = values[..., sort_inds]
|
||||
predictions = target_inds[sort_inds]
|
||||
|
||||
for v, p in zip(values.tolist(), predictions.tolist()):
|
||||
tokens = input_ids.numpy()
|
||||
|
||||
Reference in New Issue
Block a user