Use labels to remove deprecation warnings (#4807)
This commit is contained in:
@@ -151,7 +151,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = DistilBertForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
|
||||
Reference in New Issue
Block a user