fix labels (#6213)
This commit is contained in:
@@ -87,6 +87,7 @@ class DataCollatorForLanguageModeling:
|
|||||||
return {"input_ids": inputs, "labels": labels}
|
return {"input_ids": inputs, "labels": labels}
|
||||||
else:
|
else:
|
||||||
labels = batch.clone().detach()
|
labels = batch.clone().detach()
|
||||||
|
if self.tokenizer.pad_token_id is not None:
|
||||||
labels[labels == self.tokenizer.pad_token_id] = -100
|
labels[labels == self.tokenizer.pad_token_id] = -100
|
||||||
return {"input_ids": batch, "labels": labels}
|
return {"input_ids": batch, "labels": labels}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user