From fa218e648abc4f2c2d8a897ed0b4f2f050ecaca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 10 Oct 2019 15:16:07 +0200 Subject: [PATCH] fix syntax errors --- transformers/modeling_bert.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 5d53b981e5..bce7972315 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -201,7 +201,7 @@ class BertSelfAttention(nn.Module): def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None): mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) - if encoder_hidden_states: # if encoder-decoder attention + if encoder_hidden_states is not None: # if encoder-decoder attention mixed_query_layer = self.query(encoder_hidden_states) else: mixed_query_layer = self.query(hidden_states) @@ -331,11 +331,12 @@ class BertLayer(nn.Module): attention_outputs = self.attention(hidden_states, attention_mask, head_mask) attention_output = attention_outputs[0] - if encoder_hidden_state: + if encoder_hidden_state is not None: try: attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state) except AttributeError as ae: - raise ae("you need to set `is_encoder` to True in the configuration to instantiate an encoder layer") + print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae) + raise attention_output = attention_outputs[0] intermediate_output = self.intermediate(attention_output) @@ -382,7 +383,7 @@ class BertDecoder(nn.Module): config.is_decoder = True self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states - self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None): all_hidden_states = () @@ -738,7 +739,7 @@ class BertDecoderModel(BertPreTrainedModel): self.decoder.layer[layer].attention.prune_heads(heads) self.decoder.layer[layer].crossattention.prune_heads(heads) - def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None): + def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False): if attention_mask is None: attention_mask = torch.ones_like(input_ids) if token_type_ids is None: @@ -782,7 +783,7 @@ class BertDecoderModel(BertPreTrainedModel): sequence_output = decoder_outputs[0] pooled_output = self.pooler(sequence_output) - outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here + outputs = (sequence_output, pooled_output,) + decoder_outputs[1:] # add hidden_states and attentions if they are here return outputs # sequence_output, pooled_output, (hidden_states), (attentions) @@ -1387,8 +1388,7 @@ class Bert2Rnd(BertPreTrainedModel): head_mask=head_mask) encoder_output = encoder_outputs[0] - decoder_input = torch.empty_like(input_ids).normal_(mean=0.0, std=self.config.initializer_range) - decoder_outputs = self.decoder(decoder_input, + decoder_outputs = self.decoder(input_ids, encoder_output, token_type_ids=token_type_ids, position_ids=position_ids,