[Bert2Bert] allow bert2bert + relative embeddings (#14324)
* [Bert2Bert] allow bert2bert + relative embeddings
* up
* Update README_ko.md
* up
* up
This commit is contained in:
Committed by: GitHub
Parent: e4d8f517b9
Commit: e81d8d7fa9
@@ -203,7 +203,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
@@ -220,7 +220,7 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
|
||||
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||
@@ -344,9 +344,9 @@ class {{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->{{cookiecutter.camelcase_modelname}}
|
||||
class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, position_embedding_type=None):
|
||||
super().__init__()
|
||||
self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config)
|
||||
self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config, position_embedding_type=position_embedding_type)
|
||||
self.output = {{cookiecutter.camelcase_modelname}}SelfOutput(config)
|
||||
self.pruned_heads = set()
|
||||
|
||||
@@ -434,7 +434,7 @@ class {{cookiecutter.camelcase_modelname}}Layer(nn.Module):
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
|
||||
self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config)
|
||||
self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config, position_embedding_type="absolute")
|
||||
self.intermediate = {{cookiecutter.camelcase_modelname}}Intermediate(config)
|
||||
self.output = {{cookiecutter.camelcase_modelname}}Output(config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user