edit: cast attention_mask to long in DataCollatorCTCWithPadding (#19369)

* edit: casting attention_mask to long in DataCollatorCTCWithPadding

* edit: casting attention_mask to long in DataCollatorCTCWithPadding
This commit is contained in:
ddobokki
2022-10-07 23:05:48 +09:00
committed by GitHub
parent e9a49babee
commit fa4bcd5274

View File

@@ -317,6 +317,8 @@ class DataCollatorCTCWithPadding:
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
batch["labels"] = labels batch["labels"] = labels
if "attention_mask" in batch:
batch["attention_mask"] = batch["attention_mask"].to(torch.long)
return batch return batch