WIP reodering arguments for torchscript and TF
This commit is contained in:
@@ -126,8 +126,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertModel(config=config)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids, token_type_ids=token_type_ids)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
|
||||
result = {
|
||||
@@ -143,7 +143,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertForMaskedLM(config=config)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
@@ -156,7 +156,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertForNextSentencePrediction(config=config)
|
||||
model.eval()
|
||||
loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
||||
loss, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"seq_relationship_score": seq_relationship_score,
|
||||
@@ -170,7 +170,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertForPreTraining(config=config)
|
||||
model.eval()
|
||||
loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
|
||||
loss, prediction_scores, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
|
||||
masked_lm_labels=token_labels, next_sentence_label=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
@@ -188,7 +189,8 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = BertForQuestionAnswering(config=config)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
|
||||
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
@@ -207,7 +209,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
config.num_labels = self.num_labels
|
||||
model = BertForSequenceClassification(config)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
@@ -222,7 +224,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
config.num_labels = self.num_labels
|
||||
model = BertForTokenClassification(config=config)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
@@ -241,9 +243,9 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(multiple_choice_inputs_ids,
|
||||
multiple_choice_token_type_ids,
|
||||
multiple_choice_input_mask,
|
||||
choice_labels)
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
|
||||
@@ -148,7 +148,7 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
def create_and_check_distilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DistilBertForQuestionAnswering(config=config)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(input_ids, input_mask, sequence_labels, sequence_labels)
|
||||
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
@@ -166,7 +166,7 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
config.num_labels = self.num_labels
|
||||
model = DistilBertForSequenceClassification(config)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, input_mask, sequence_labels)
|
||||
loss, logits = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"logits": logits,
|
||||
|
||||
Reference in New Issue
Block a user