Use labels to remove deprecation warnings (#4807)

This commit is contained in:
Sylvain Gugger
2020-06-05 16:41:46 -04:00
committed by GitHub
parent 5c0cfc2cf0
commit f1fe18465d
10 changed files with 17 additions and 17 deletions

View File

@@ -151,7 +151,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
model = DistilBertForMaskedLM(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=token_labels)
result = {
"loss": loss,
"prediction_scores": prediction_scores,