From 5c5af879b6d45c879c987154f66d4ea978925fb2 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Tue, 3 Mar 2020 15:14:12 -0500 Subject: [PATCH] [Bart] dont call .forward (#3094) --- src/transformers/modeling_bart.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index f832d88575..286d0f0ea4 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -208,7 +208,7 @@ class EncoderLayer(nn.Module): encoded output of shape `(seq_len, batch, embed_dim)` """ residual = x - x, attn_weights = self.self_attn.forward( + x, attn_weights = self.self_attn( query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=self.output_attentions, ) x = F.dropout(x, p=self.dropout, training=self.training) @@ -292,7 +292,7 @@ class BartEncoder(nn.Module): if self.training and (dropout_probability < self.layerdrop): # skip the layer attn = None else: - x, attn = encoder_layer.forward(x, attention_mask) + x, attn = encoder_layer(x, attention_mask) if self.output_attentions: all_attentions.append(attn) @@ -356,7 +356,7 @@ class DecoderLayer(nn.Module): if layer_state is None: layer_state = {} # next line mutates layer state - x, self_attn_weights = self.self_attn.forward( + x, self_attn_weights = self.self_attn( query=x, key=y, value=y, layer_state=layer_state, need_weights=need_attn_weights, attn_mask=attention_mask, ) x = F.dropout(x, p=self.dropout, training=self.training) @@ -365,7 +365,7 @@ class DecoderLayer(nn.Module): residual = x assert self.encoder_attn.cache_key != self.self_attn.cache_key - x, encoder_attn_weights = self.encoder_attn.forward( + x, encoder_attn_weights = self.encoder_attn( query=x, key=encoder_hidden_states, # could be None value=encoder_hidden_states, @@ -449,7 +449,7 @@ class BartDecoder(nn.Module): - attentions """ # embed positions - positions = self.embed_positions.forward(input_ids, generation_mode=self.generation_mode) + positions = self.embed_positions(input_ids, generation_mode=self.generation_mode) if self.generation_mode: input_ids = input_ids[:, -1:] @@ -475,7 +475,7 @@ class BartDecoder(nn.Module): continue layer_state = decoder_cached_states[i] if decoder_cached_states is not None else None - x, layer_self_attn, layer_past = decoder_layer.forward( + x, layer_self_attn, layer_past = decoder_layer( x, encoder_hidden_states, encoder_padding_mask, @@ -836,10 +836,10 @@ class BartModel(PretrainedBartModel): ) assert decoder_input_ids is not None if encoder_outputs is None: - encoder_outputs = self.encoder.forward(input_ids=input_ids, attention_mask=attention_mask) + encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask) assert isinstance(encoder_outputs, tuple) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - decoder_outputs = self.decoder.forward( + decoder_outputs = self.decoder( decoder_input_ids, encoder_outputs[0], attention_mask, @@ -925,7 +925,7 @@ class BartForMaskedLM(PretrainedBartModel): outputs = model(input_ids=input_ids, lm_labels=input_ids) loss, prediction_scores = outputs[:2] """ - outputs = self.model.forward( + outputs = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -933,7 +933,7 @@ class BartForMaskedLM(PretrainedBartModel): decoder_attention_mask=decoder_attention_mask, decoder_cached_states=decoder_cached_states, ) - lm_logits = self.lm_head.forward(outputs[0]) + lm_logits = self.lm_head(outputs[0]) outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here if lm_labels is not None: loss_fct = nn.CrossEntropyLoss() @@ -1308,7 +1308,7 @@ class BartForSequenceClassification(PretrainedBartModel): loss, logits = outputs[:2] """ - outputs = self.model.forward( + outputs = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids,