edit: cast attention_mask to long in DataCollatorCTCWithPadding (#19369)
* edit: casting attention_mask to long in DataCollatorCTCWithPadding * edit: casting attention_mask to long in DataCollatorCTCWithPadding
This commit is contained in:
@@ -317,6 +317,8 @@ class DataCollatorCTCWithPadding:
|
||||
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
||||
|
||||
batch["labels"] = labels
|
||||
if "attention_mask" in batch:
|
||||
batch["attention_mask"] = batch["attention_mask"].to(torch.long)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user