From fa4bcd5274125278ce6f97438760b18640b83e62 Mon Sep 17 00:00:00 2001 From: ddobokki <44228269+ddobokki@users.noreply.github.com> Date: Fri, 7 Oct 2022 23:05:48 +0900 Subject: [PATCH] edit: cast attention_mask to long in DataCollatorCTCWithPadding (#19369) * edit: casting attention_mask to long in DataCollatorCTCWithPadding * edit: casting attention_mask to long in DataCollatorCTCWithPadding --- .../pytorch/speech-recognition/run_speech_recognition_ctc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py index 54ea4e17f4..904a297c5a 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py @@ -317,6 +317,8 @@ class DataCollatorCTCWithPadding: labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels + if "attention_mask" in batch: + batch["attention_mask"] = batch["attention_mask"].to(torch.long) return batch