edit: cast attention_mask to long in DataCollatorCTCWithPadding (#19369)
* edit: casting attention_mask to long in DataCollatorCTCWithPadding
* edit: casting attention_mask to long in DataCollatorCTCWithPadding
This commit is contained in:
@@ -317,6 +317,8 @@ class DataCollatorCTCWithPadding:
|
|||||||
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
||||||
|
|
||||||
batch["labels"] = labels
|
batch["labels"] = labels
|
||||||
|
if "attention_mask" in batch:
|
||||||
|
batch["attention_mask"] = batch["attention_mask"].to(torch.long)
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user