Use labels to remove deprecation warnings (#4807)
This commit is contained in:
@@ -169,7 +169,7 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
loss, lm_logits, mc_logits = model(input_ids, token_type_ids=token_type_ids, lm_labels=input_ids)
|
||||
loss, lm_logits, mc_logits = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
|
||||
|
||||
result = {"loss": loss, "lm_logits": lm_logits}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user