Fix dpr<>bart config for RAG (#8808)

* correct dpr test and bert pos fault

* fix dpr bert config problem

* fix layoutlm

* add config to dpr as well
This commit is contained in:
Patrick von Platen
2020-11-27 16:26:45 +01:00
committed by GitHub
parent a2cf37595e
commit a7d46a0609
7 changed files with 30 additions and 22 deletions

View File

@@ -178,7 +178,7 @@ class BertEmbeddings(nn.Module):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
if input_ids is not None:
@@ -225,7 +225,7 @@ class BertSelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = config.position_embedding_type
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)