Fix labels: only mask pad tokens with -100 when the tokenizer defines a pad token (#6213)

This commit is contained in:
Suraj Patil
2020-08-03 19:49:35 +05:30
committed by GitHub
parent cedc547e7e
commit 0b41867357

View File

@@ -87,6 +87,7 @@ class DataCollatorForLanguageModeling:
return {"input_ids": inputs, "labels": labels} return {"input_ids": inputs, "labels": labels}
else: else:
labels = batch.clone().detach() labels = batch.clone().detach()
if self.tokenizer.pad_token_id is not None:
labels[labels == self.tokenizer.pad_token_id] = -100 labels[labels == self.tokenizer.pad_token_id] = -100
return {"input_ids": batch, "labels": labels} return {"input_ids": batch, "labels": labels}