adding option to desactivate past/memory outputs
This commit is contained in:
@@ -269,16 +269,16 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super(CTRLModel, self).__init__(config)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.d_model_size = config.n_embd
|
||||
self.num_layers = config.n_layer
|
||||
|
||||
|
||||
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
|
||||
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.w = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
|
||||
|
||||
self.dropout = nn.Dropout(config.embd_pdrop)
|
||||
self.h = nn.ModuleList([EncoderLayer(config.n_embd,
|
||||
config.n_head,
|
||||
@@ -378,7 +378,8 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i])
|
||||
hidden_states, present = outputs[:2]
|
||||
presents = presents + (present,)
|
||||
if self.output_past:
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions.append(outputs[2])
|
||||
@@ -388,7 +389,9 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states, presents)
|
||||
outputs = (hidden_states,)
|
||||
if self.output_past:
|
||||
outputs = outputs + (presents,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
|
||||
Reference in New Issue
Block a user