Use labels to remove deprecation warnings (#4807)
This commit is contained in:
@@ -164,7 +164,7 @@ class LongformerModelTester(object):
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
@@ -361,7 +361,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
[[0] + [20920, 232, 328, 1437] * 1000 + [2]], dtype=torch.long, device=torch_device
|
||||
) # long input
|
||||
|
||||
loss, prediction_scores = model(input_ids, masked_lm_labels=input_ids)
|
||||
loss, prediction_scores = model(input_ids, labels=input_ids)
|
||||
|
||||
expected_loss = torch.tensor(0.0620, device=torch_device)
|
||||
expected_prediction_scores_sum = torch.tensor(-6.1599e08, device=torch_device)
|
||||
|
||||
Reference in New Issue
Block a user