This commit is contained in:
Patrick von Platen
2021-03-22 10:32:21 +03:00
committed by GitHub
parent 82b8d8c7b0
commit 0f226f78ce

View File

@@ -401,7 +401,7 @@ def evaluate(batch):
with torch.no_grad():
logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)
pred_ids = torch.argmax(logits, dim=-1)
batch["pred_strings"] = processor.batch_decode(pred_ids)
return batch