Use labels to remove deprecation warnings (#4807)

This commit is contained in:
Sylvain Gugger
2020-06-05 16:41:46 -04:00
committed by GitHub
parent 5c0cfc2cf0
commit f1fe18465d
10 changed files with 17 additions and 17 deletions

View File

@@ -218,7 +218,7 @@ class BertModelTester:
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
@@ -248,7 +248,7 @@ class BertModelTester:
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
masked_lm_labels=token_labels,
labels=token_labels,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
@@ -256,7 +256,7 @@ class BertModelTester:
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
masked_lm_labels=token_labels,
labels=token_labels,
encoder_hidden_states=encoder_hidden_states,
)
result = {
@@ -294,7 +294,7 @@ class BertModelTester:
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
masked_lm_labels=token_labels,
labels=token_labels,
next_sentence_label=sequence_labels,
)
result = {