adding option to desactivate past/memory outputs

This commit is contained in:
thomwolf
2019-10-11 15:47:08 +02:00
parent 2a4fef837a
commit 0f9fc4fbde
8 changed files with 93 additions and 55 deletions

View File

@@ -269,16 +269,16 @@ class CTRLModel(CTRLPreTrainedModel):
def __init__(self, config):
super(CTRLModel, self).__init__(config)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)
self.output_attentions = config.output_attentions
self.w = nn.Embedding(config.vocab_size, config.n_embd)
self.dropout = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([EncoderLayer(config.n_embd,
config.n_head,
@@ -378,7 +378,8 @@ class CTRLModel(CTRLPreTrainedModel):
attention_mask=attention_mask,
head_mask=head_mask[i])
hidden_states, present = outputs[:2]
presents = presents + (present,)
if self.output_past:
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(outputs[2])
@@ -388,7 +389,9 @@ class CTRLModel(CTRLPreTrainedModel):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents)
outputs = (hidden_states,)
if self.output_past:
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions: