Use labels to remove deprecation warnings (#4807)
This commit is contained in:
@@ -218,7 +218,7 @@ class BertModelTester:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
@@ -248,7 +248,7 @@ class BertModelTester:
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
masked_lm_labels=token_labels,
|
||||
labels=token_labels,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
@@ -256,7 +256,7 @@ class BertModelTester:
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
masked_lm_labels=token_labels,
|
||||
labels=token_labels,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
result = {
|
||||
@@ -294,7 +294,7 @@ class BertModelTester:
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
masked_lm_labels=token_labels,
|
||||
labels=token_labels,
|
||||
next_sentence_label=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
|
||||
Reference in New Issue
Block a user