[DPR] Correct init (#13796)
* update * add to docs and init * make fix-copies
This commit is contained in:
committed by
GitHub
parent
44eb8bdeea
commit
41436d3dfb
@@ -41,6 +41,13 @@ DPRConfig
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
DPRPreTrainedModel
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.DPRPreTrainedModel
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
DPRContextEncoderTokenizer
|
DPRContextEncoderTokenizer
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -773,6 +773,7 @@ if is_torch_available():
|
|||||||
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"DPRContextEncoder",
|
"DPRContextEncoder",
|
||||||
"DPRPretrainedContextEncoder",
|
"DPRPretrainedContextEncoder",
|
||||||
|
"DPRPreTrainedModel",
|
||||||
"DPRPretrainedQuestionEncoder",
|
"DPRPretrainedQuestionEncoder",
|
||||||
"DPRPretrainedReader",
|
"DPRPretrainedReader",
|
||||||
"DPRQuestionEncoder",
|
"DPRQuestionEncoder",
|
||||||
@@ -2512,6 +2513,7 @@ if TYPE_CHECKING:
|
|||||||
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
DPRContextEncoder,
|
DPRContextEncoder,
|
||||||
DPRPretrainedContextEncoder,
|
DPRPretrainedContextEncoder,
|
||||||
|
DPRPreTrainedModel,
|
||||||
DPRPretrainedQuestionEncoder,
|
DPRPretrainedQuestionEncoder,
|
||||||
DPRPretrainedReader,
|
DPRPretrainedReader,
|
||||||
DPRQuestionEncoder,
|
DPRQuestionEncoder,
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ if is_torch_available():
|
|||||||
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"DPRContextEncoder",
|
"DPRContextEncoder",
|
||||||
"DPRPretrainedContextEncoder",
|
"DPRPretrainedContextEncoder",
|
||||||
|
"DPRPreTrainedModel",
|
||||||
"DPRPretrainedQuestionEncoder",
|
"DPRPretrainedQuestionEncoder",
|
||||||
"DPRPretrainedReader",
|
"DPRPretrainedReader",
|
||||||
"DPRQuestionEncoder",
|
"DPRQuestionEncoder",
|
||||||
@@ -89,6 +90,7 @@ if TYPE_CHECKING:
|
|||||||
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
DPRContextEncoder,
|
DPRContextEncoder,
|
||||||
DPRPretrainedContextEncoder,
|
DPRPretrainedContextEncoder,
|
||||||
|
DPRPreTrainedModel,
|
||||||
DPRPretrainedQuestionEncoder,
|
DPRPretrainedQuestionEncoder,
|
||||||
DPRPretrainedReader,
|
DPRPretrainedReader,
|
||||||
DPRQuestionEncoder,
|
DPRQuestionEncoder,
|
||||||
|
|||||||
@@ -147,7 +147,29 @@ class DPRReaderOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
class DPREncoder(PreTrainedModel):
|
class DPRPreTrainedModel(PreTrainedModel):
|
||||||
|
def _init_weights(self, module):
|
||||||
|
"""Initialize the weights"""
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||||
|
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.Embedding):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.padding_idx is not None:
|
||||||
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
|
if isinstance(module, BertEncoder):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
|
||||||
|
class DPREncoder(DPRPreTrainedModel):
|
||||||
|
|
||||||
base_model_prefix = "bert_model"
|
base_model_prefix = "bert_model"
|
||||||
|
|
||||||
@@ -200,13 +222,8 @@ class DPREncoder(PreTrainedModel):
|
|||||||
return self.encode_proj.out_features
|
return self.encode_proj.out_features
|
||||||
return self.bert_model.config.hidden_size
|
return self.bert_model.config.hidden_size
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
self.bert_model.init_weights()
|
|
||||||
if self.projection_dim > 0:
|
|
||||||
self.encode_proj.apply(self.bert_model._init_weights)
|
|
||||||
|
|
||||||
|
class DPRSpanPredictor(DPRPreTrainedModel):
|
||||||
class DPRSpanPredictor(PreTrainedModel):
|
|
||||||
|
|
||||||
base_model_prefix = "encoder"
|
base_model_prefix = "encoder"
|
||||||
|
|
||||||
@@ -262,16 +279,13 @@ class DPRSpanPredictor(PreTrainedModel):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
self.encoder.init_weights()
|
|
||||||
|
|
||||||
|
|
||||||
##################
|
##################
|
||||||
# PreTrainedModel
|
# PreTrainedModel
|
||||||
##################
|
##################
|
||||||
|
|
||||||
|
|
||||||
class DPRPretrainedContextEncoder(PreTrainedModel):
|
class DPRPretrainedContextEncoder(DPRPreTrainedModel):
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
models.
|
models.
|
||||||
@@ -282,11 +296,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
|
|||||||
base_model_prefix = "ctx_encoder"
|
base_model_prefix = "ctx_encoder"
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
self.ctx_encoder.init_weights()
|
|
||||||
|
|
||||||
|
class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
|
||||||
class DPRPretrainedQuestionEncoder(PreTrainedModel):
|
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
models.
|
models.
|
||||||
@@ -297,15 +308,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
|
|||||||
base_model_prefix = "question_encoder"
|
base_model_prefix = "question_encoder"
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
self.question_encoder.init_weights()
|
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
class DPRPretrainedReader(DPRPreTrainedModel):
|
||||||
if isinstance(module, BertEncoder):
|
|
||||||
module.gradient_checkpointing = value
|
|
||||||
|
|
||||||
|
|
||||||
class DPRPretrainedReader(PreTrainedModel):
|
|
||||||
"""
|
"""
|
||||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
models.
|
models.
|
||||||
@@ -316,15 +320,6 @@ class DPRPretrainedReader(PreTrainedModel):
|
|||||||
base_model_prefix = "span_predictor"
|
base_model_prefix = "span_predictor"
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def init_weights(self):
|
|
||||||
self.span_predictor.encoder.init_weights()
|
|
||||||
self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights)
|
|
||||||
self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights)
|
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
|
||||||
if isinstance(module, BertEncoder):
|
|
||||||
module.gradient_checkpointing = value
|
|
||||||
|
|
||||||
|
|
||||||
###############
|
###############
|
||||||
# Actual Models
|
# Actual Models
|
||||||
|
|||||||
@@ -1462,6 +1462,15 @@ class DPRPretrainedContextEncoder:
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class DPRPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class DPRPretrainedQuestionEncoder:
|
class DPRPretrainedQuestionEncoder:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import DPRConfig, is_torch_available
|
from transformers import DPRConfig, is_torch_available
|
||||||
@@ -213,6 +214,19 @@ class DPRModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_reader(*config_and_inputs)
|
self.model_tester.create_and_check_reader(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_init_changed_config(self):
|
||||||
|
config = self.model_tester.prepare_config_and_inputs()[0]
|
||||||
|
|
||||||
|
model = DPRQuestionEncoder(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
model.save_pretrained(tmp_dirname)
|
||||||
|
model = DPRQuestionEncoder.from_pretrained(tmp_dirname, projection_dim=512)
|
||||||
|
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
for model_name in DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
|||||||
Reference in New Issue
Block a user