From 17177e73796f516e3f49d311eab77b02ab679871 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 10 Oct 2019 12:03:58 +0200 Subject: [PATCH] add is_decoder as an attribute to Config class --- transformers/modeling_bert.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index a5b21510aa..9e03c2f8d4 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -322,7 +322,7 @@ class BertLayer(nn.Module): def __init__(self, config): super(BertLayer, self).__init__() self.self_attention = BertAttention(config) - if config.get("is_decoder", False): + if getattr(config, "is_decoder", False): self.attention = BertAttention(config) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) @@ -380,7 +380,7 @@ class BertEncoder(nn.Module): class BertDecoder(nn.Module): def __init__(self, config): super(BertDecoder, self).__init__() - config["is_decoder"] = True + config.is_decoder = True self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])