add a placeholder test
This commit is contained in:
@@ -64,7 +64,7 @@ if is_torch_available():
|
|||||||
BertForMaskedLM, BertForNextSentencePrediction,
|
BertForMaskedLM, BertForNextSentencePrediction,
|
||||||
BertForSequenceClassification, BertForMultipleChoice,
|
BertForSequenceClassification, BertForMultipleChoice,
|
||||||
BertForTokenClassification, BertForQuestionAnswering,
|
BertForTokenClassification, BertForQuestionAnswering,
|
||||||
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, Bert2Bert)
|
||||||
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
|
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
|
||||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
|
||||||
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ if is_torch_available():
|
|||||||
from transformers import (BertConfig, BertModel, BertForMaskedLM,
|
from transformers import (BertConfig, BertModel, BertForMaskedLM,
|
||||||
BertForNextSentencePrediction, BertForPreTraining,
|
BertForNextSentencePrediction, BertForPreTraining,
|
||||||
BertForQuestionAnswering, BertForSequenceClassification,
|
BertForQuestionAnswering, BertForSequenceClassification,
|
||||||
BertForTokenClassification, BertForMultipleChoice)
|
BertForTokenClassification, BertForMultipleChoice, Bert2Bert)
|
||||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
else:
|
else:
|
||||||
pytestmark = pytest.mark.skip("Require Torch")
|
pytestmark = pytest.mark.skip("Require Torch")
|
||||||
@@ -145,7 +145,6 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
[self.batch_size, self.seq_length, self.hidden_size])
|
[self.batch_size, self.seq_length, self.hidden_size])
|
||||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertForMaskedLM(config=config)
|
model = BertForMaskedLM(config=config)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -172,7 +171,6 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
[self.batch_size, 2])
|
[self.batch_size, 2])
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertForPreTraining(config=config)
|
model = BertForPreTraining(config=config)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -191,7 +189,6 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
[self.batch_size, 2])
|
[self.batch_size, 2])
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertForQuestionAnswering(config=config)
|
model = BertForQuestionAnswering(config=config)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -210,7 +207,6 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
[self.batch_size, self.seq_length])
|
[self.batch_size, self.seq_length])
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
model = BertForSequenceClassification(config)
|
model = BertForSequenceClassification(config)
|
||||||
@@ -225,7 +221,6 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
[self.batch_size, self.num_labels])
|
[self.batch_size, self.num_labels])
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
config.num_labels = self.num_labels
|
config.num_labels = self.num_labels
|
||||||
model = BertForTokenClassification(config=config)
|
model = BertForTokenClassification(config=config)
|
||||||
@@ -240,7 +235,6 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
[self.batch_size, self.seq_length, self.num_labels])
|
[self.batch_size, self.seq_length, self.num_labels])
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
config.num_choices = self.num_choices
|
config.num_choices = self.num_choices
|
||||||
model = BertForMultipleChoice(config=config)
|
model = BertForMultipleChoice(config=config)
|
||||||
@@ -261,6 +255,16 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
[self.batch_size, self.num_choices])
|
[self.batch_size, self.num_choices])
|
||||||
self.check_loss_output(result)
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
def create_and_check_bert2bert(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
|
config.num_choices = self.num_choices
|
||||||
|
model = Bert2Bert(config=config)
|
||||||
|
model.eval()
|
||||||
|
bert2bert_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
bert2bert_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
bert2bert_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
_ = model(bert2bert_inputs_ids,
|
||||||
|
attention_mask=bert2bert_input_mask,
|
||||||
|
token_type_ids=bert2bert_token_type_ids)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
@@ -316,5 +320,6 @@ class BertModelTest(CommonTestCases.CommonModelTester):
|
|||||||
shutil.rmtree(cache_dir)
|
shutil.rmtree(cache_dir)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user