add is_decoder as an attribute to Config class
This commit is contained in:
@@ -322,7 +322,7 @@ class BertLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertLayer, self).__init__()
|
||||
self.self_attention = BertAttention(config)
|
||||
if config.get("is_decoder", False):
|
||||
if getattr(config, "is_decoder", False):
|
||||
self.attention = BertAttention(config)
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
@@ -380,7 +380,7 @@ class BertEncoder(nn.Module):
|
||||
class BertDecoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertDecoder, self).__init__()
|
||||
config["is_decoder"] = True
|
||||
config.is_decoder = True
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
Reference in New Issue
Block a user