add is_decoder as an attribute to Config class

This commit is contained in:
Rémi Louf
2019-10-10 12:03:58 +02:00
parent df85a0ff0b
commit 17177e7379

View File

@@ -322,7 +322,7 @@ class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
self.self_attention = BertAttention(config)
if config.get("is_decoder", False):
if getattr(config, "is_decoder", False):
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
@@ -380,7 +380,7 @@ class BertEncoder(nn.Module):
class BertDecoder(nn.Module):
def __init__(self, config):
super(BertDecoder, self).__init__()
config["is_decoder"] = True
config.is_decoder = True
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])