Fix dpr<>bart config for RAG (#8808)
* correct dpr test and bert pos fault * fix dpr bert config problem * fix layoutlm * add config to dpr as well
This commit is contained in:
committed by
GitHub
parent
a2cf37595e
commit
a7d46a0609
@@ -83,7 +83,7 @@ class RobertaEmbeddings(nn.Module):
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
|
||||
# End copy
|
||||
self.padding_idx = config.pad_token_id
|
||||
@@ -162,7 +162,7 @@ class RobertaSelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = config.position_embedding_type
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
|
||||
Reference in New Issue
Block a user