adding option to desactivate past/memory outputs
This commit is contained in:
@@ -168,12 +168,14 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
super(TFCTRLMainLayer, self).__init__(**kwargs)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.d_model_size = config.n_embd
|
||||
self.num_layers = config.n_layer
|
||||
|
||||
self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size)
|
||||
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.w = TFSharedEmbeddings(config.vocab_size,
|
||||
config.n_embd,
|
||||
@@ -290,7 +292,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training)
|
||||
hidden_states, present = outputs[:2]
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_past:
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions.append(outputs[2])
|
||||
@@ -300,7 +304,9 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states, presents)
|
||||
outputs = (hidden_states,)
|
||||
if self.output_past:
|
||||
outputs = outputs + (presents,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
|
||||
Reference in New Issue
Block a user