[fix] slow fill_mask test failure (#5406)
This commit is contained in:
@@ -827,6 +827,7 @@ class FillMaskPipeline(Pipeline):
|
|||||||
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
||||||
else:
|
else:
|
||||||
masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero().item()
|
masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero().item()
|
||||||
|
|
||||||
logits = outputs[i, masked_index, :]
|
logits = outputs[i, masked_index, :]
|
||||||
probs = logits.softmax(dim=0)
|
probs = logits.softmax(dim=0)
|
||||||
values, predictions = probs.topk(self.topk)
|
values, predictions = probs.topk(self.topk)
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ TF_TRANSLATION_FINETUNED_MODELS = [("patrickvonplaten/t5-tiny-random", "translat
|
|||||||
|
|
||||||
expected_fill_mask_result = [
|
expected_fill_mask_result = [
|
||||||
[
|
[
|
||||||
{"sequence": "<s> My name is:</s>", "score": 0.009954338893294334, "token": 35},
|
{"sequence": "<s>My name is John</s>", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"},
|
||||||
{"sequence": "<s> My name is John</s>", "score": 0.0080940006300807, "token": 610},
|
{"sequence": "<s>My name is Chris</s>", "score": 0.007475061342120171, "token": 1573, "token_str": "ĠChris"},
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
{"sequence": "<s>The largest city in France is Paris</s>", "score": 0.3185044229030609, "token": 2201},
|
{"sequence": "<s>The largest city in France is Paris</s>", "score": 0.3185044229030609, "token": 2201},
|
||||||
|
|||||||
Reference in New Issue
Block a user