add is_decoder as an attribute to Config class
This commit is contained in:
@@ -322,7 +322,7 @@ class BertLayer(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertLayer, self).__init__()
|
super(BertLayer, self).__init__()
|
||||||
self.self_attention = BertAttention(config)
|
self.self_attention = BertAttention(config)
|
||||||
if config.get("is_decoder", False):
|
if getattr(config, "is_decoder", False):
|
||||||
self.attention = BertAttention(config)
|
self.attention = BertAttention(config)
|
||||||
self.intermediate = BertIntermediate(config)
|
self.intermediate = BertIntermediate(config)
|
||||||
self.output = BertOutput(config)
|
self.output = BertOutput(config)
|
||||||
@@ -380,7 +380,7 @@ class BertEncoder(nn.Module):
|
|||||||
class BertDecoder(nn.Module):
|
class BertDecoder(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertDecoder, self).__init__()
|
super(BertDecoder, self).__init__()
|
||||||
config["is_decoder"] = True
|
config.is_decoder = True
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
|
|||||||
Reference in New Issue
Block a user