[DPR] Correct init (#13796)
* update
* add to docs and init
* make fix-copies
This commit is contained in:
committed by
GitHub
parent
44eb8bdeea
commit
41436d3dfb
@@ -41,6 +41,13 @@ DPRConfig
|
||||
:members:
|
||||
|
||||
|
||||
DPRPreTrainedModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.DPRPreTrainedModel
|
||||
:members:
|
||||
|
||||
|
||||
DPRContextEncoderTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -773,6 +773,7 @@ if is_torch_available():
|
||||
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"DPRContextEncoder",
|
||||
"DPRPretrainedContextEncoder",
|
||||
"DPRPreTrainedModel",
|
||||
"DPRPretrainedQuestionEncoder",
|
||||
"DPRPretrainedReader",
|
||||
"DPRQuestionEncoder",
|
||||
@@ -2512,6 +2513,7 @@ if TYPE_CHECKING:
|
||||
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DPRContextEncoder,
|
||||
DPRPretrainedContextEncoder,
|
||||
DPRPreTrainedModel,
|
||||
DPRPretrainedQuestionEncoder,
|
||||
DPRPretrainedReader,
|
||||
DPRQuestionEncoder,
|
||||
|
||||
@@ -46,6 +46,7 @@ if is_torch_available():
|
||||
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"DPRContextEncoder",
|
||||
"DPRPretrainedContextEncoder",
|
||||
"DPRPreTrainedModel",
|
||||
"DPRPretrainedQuestionEncoder",
|
||||
"DPRPretrainedReader",
|
||||
"DPRQuestionEncoder",
|
||||
@@ -89,6 +90,7 @@ if TYPE_CHECKING:
|
||||
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
DPRContextEncoder,
|
||||
DPRPretrainedContextEncoder,
|
||||
DPRPreTrainedModel,
|
||||
DPRPretrainedQuestionEncoder,
|
||||
DPRPretrainedReader,
|
||||
DPRQuestionEncoder,
|
||||
|
||||
@@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class DPREncoder(PreTrainedModel):
|
||||
class DPRPreTrainedModel(PreTrainedModel):
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, BertEncoder):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
class DPREncoder(DPRPreTrainedModel):
|
||||
|
||||
base_model_prefix = "bert_model"
|
||||
|
||||
@@ -200,13 +222,8 @@ class DPREncoder(PreTrainedModel):
|
||||
return self.encode_proj.out_features
|
||||
return self.bert_model.config.hidden_size
|
||||
|
||||
def init_weights(self):
|
||||
self.bert_model.init_weights()
|
||||
if self.projection_dim > 0:
|
||||
self.encode_proj.apply(self.bert_model._init_weights)
|
||||
|
||||
|
||||
class DPRSpanPredictor(PreTrainedModel):
|
||||
class DPRSpanPredictor(DPRPreTrainedModel):
|
||||
|
||||
base_model_prefix = "encoder"
|
||||
|
||||
@@ -262,16 +279,13 @@ class DPRSpanPredictor(PreTrainedModel):
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def init_weights(self):
|
||||
self.encoder.init_weights()
|
||||
|
||||
|
||||
##################
|
||||
# PreTrainedModel
|
||||
##################
|
||||
|
||||
|
||||
class DPRPretrainedContextEncoder(PreTrainedModel):
|
||||
class DPRPretrainedContextEncoder(DPRPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
@@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
|
||||
base_model_prefix = "ctx_encoder"
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def init_weights(self):
|
||||
self.ctx_encoder.init_weights()
|
||||
|
||||
|
||||
class DPRPretrainedQuestionEncoder(PreTrainedModel):
|
||||
class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
@@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
|
||||
base_model_prefix = "question_encoder"
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def init_weights(self):
|
||||
self.question_encoder.init_weights()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, BertEncoder):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
class DPRPretrainedReader(PreTrainedModel):
|
||||
class DPRPretrainedReader(DPRPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
@@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel):
|
||||
base_model_prefix = "span_predictor"
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def init_weights(self):
|
||||
self.span_predictor.encoder.init_weights()
|
||||
self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights)
|
||||
self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, BertEncoder):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
###############
|
||||
# Actual Models
|
||||
|
||||
@@ -1462,6 +1462,15 @@ class DPRPretrainedContextEncoder:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class DPRPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class DPRPretrainedQuestionEncoder:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import DPRConfig, is_torch_available
|
||||
@@ -213,6 +214,19 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reader(*config_and_inputs)
|
||||
|
||||
def test_init_changed_config(self):
|
||||
config = self.model_tester.prepare_config_and_inputs()[0]
|
||||
|
||||
model = DPRQuestionEncoder(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
model.save_pretrained(tmp_dirname)
|
||||
model = DPRQuestionEncoder.from_pretrained(tmp_dirname, projection_dim=512)
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
|
||||
Reference in New Issue
Block a user