[fix] slow fill_mask test failure (#5406)
This commit is contained in:
@@ -827,6 +827,7 @@ class FillMaskPipeline(Pipeline):
|
||||
values, predictions = topk.values.numpy(), topk.indices.numpy()
|
||||
else:
|
||||
masked_index = (input_ids == self.tokenizer.mask_token_id).nonzero().item()
|
||||
|
||||
logits = outputs[i, masked_index, :]
|
||||
probs = logits.softmax(dim=0)
|
||||
values, predictions = probs.topk(self.topk)
|
||||
|
||||
@@ -31,12 +31,12 @@ TF_TRANSLATION_FINETUNED_MODELS = [("patrickvonplaten/t5-tiny-random", "translat
|
||||
|
||||
expected_fill_mask_result = [
|
||||
[
|
||||
{"sequence": "<s> My name is:</s>", "score": 0.009954338893294334, "token": 35},
|
||||
{"sequence": "<s> My name is John</s>", "score": 0.0080940006300807, "token": 610},
|
||||
{"sequence": "<s>My name is John</s>", "score": 0.00782308354973793, "token": 610, "token_str": "ĠJohn"},
|
||||
{"sequence": "<s>My name is Chris</s>", "score": 0.007475061342120171, "token": 1573, "token_str": "ĠChris"},
|
||||
],
|
||||
[
|
||||
{"sequence": "<s> The largest city in France is Paris</s>", "score": 0.3185044229030609, "token": 2201},
|
||||
{"sequence": "<s> The largest city in France is Lyon</s>", "score": 0.21112334728240967, "token": 12790},
|
||||
{"sequence": "<s>The largest city in France is Paris</s>", "score": 0.3185044229030609, "token": 2201},
|
||||
{"sequence": "<s>The largest city in France is Lyon</s>", "score": 0.21112334728240967, "token": 12790},
|
||||
],
|
||||
]
|
||||
SUMMARIZATION_KWARGS = dict(num_beams=2, min_length=2, max_length=5)
|
||||
|
||||
Reference in New Issue
Block a user