[Bert2Bert] allow bert2bert + relative embeddings (#14324)

* [Bert2Bert] allow bert2bert + relative embeddings

* up

* Update README_ko.md

* up

* up
This commit is contained in:
Patrick von Platen
2021-11-09 20:26:58 +01:00
committed by GitHub
parent e4d8f517b9
commit e81d8d7fa9
11 changed files with 70 additions and 40 deletions

View File

@@ -203,7 +203,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
def __init__(self, config):
def __init__(self, config, position_embedding_type=None):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
@@ -220,7 +220,7 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -344,9 +344,9 @@ class {{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
def __init__(self, config):
def __init__(self, config, position_embedding_type=None):
super().__init__()
self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config)
self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config, position_embedding_type=position_embedding_type)
self.output = {{cookiecutter.camelcase_modelname}}SelfOutput(config)
self.pruned_heads = set()
@@ -434,7 +434,7 @@ class {{cookiecutter.camelcase_modelname}}Layer(nn.Module):
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config)
self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config, position_embedding_type="absolute")
self.intermediate = {{cookiecutter.camelcase_modelname}}Intermediate(config)
self.output = {{cookiecutter.camelcase_modelname}}Output(config)