Fix dpr<>bart config for RAG (#8808)
* correct dpr test and bert pos fault * fix dpr bert config problem * fix layoutlm * add config to dpr as well
This commit is contained in:
committed by
GitHub
parent
a2cf37595e
commit
a7d46a0609
@@ -26,7 +26,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
from transformers import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
from transformers.models.dpr.modeling_dpr import (
|
||||
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
@@ -104,7 +104,8 @@ class DPRModelTester:
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = BertConfig(
|
||||
config = DPRConfig(
|
||||
projection_dim=self.projection_dim,
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
@@ -115,14 +116,12 @@ class DPRModelTester:
|
||||
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
def create_and_check_dpr_context_encoder(
|
||||
def create_and_check_context_encoder(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = DPRContextEncoder(config=config)
|
||||
@@ -133,7 +132,7 @@ class DPRModelTester:
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))
|
||||
|
||||
def create_and_check_dpr_question_encoder(
|
||||
def create_and_check_question_encoder(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = DPRQuestionEncoder(config=config)
|
||||
@@ -144,7 +143,7 @@ class DPRModelTester:
|
||||
result = model(input_ids)
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.projection_dim or self.hidden_size))
|
||||
|
||||
def create_and_check_dpr_reader(
|
||||
def create_and_check_reader(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = DPRReader(config=config)
|
||||
@@ -199,17 +198,17 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_dpr_context_encoder_model(self):
|
||||
def test_context_encoder_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_dpr_context_encoder(*config_and_inputs)
|
||||
self.model_tester.create_and_check_context_encoder(*config_and_inputs)
|
||||
|
||||
def test_dpr_question_encoder_model(self):
|
||||
def test_question_encoder_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_dpr_question_encoder(*config_and_inputs)
|
||||
self.model_tester.create_and_check_question_encoder(*config_and_inputs)
|
||||
|
||||
def test_dpr_reader_model(self):
|
||||
def test_reader_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_dpr_reader(*config_and_inputs)
|
||||
self.model_tester.create_and_check_reader(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
|
||||
Reference in New Issue
Block a user