@@ -745,7 +745,7 @@ class T5Stack(T5PreTrainedModel):
|
|||||||
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||||
position_bias = layer_outputs[3 if self.output_attentions else 2]
|
position_bias = layer_outputs[3 if self.output_attentions else 2]
|
||||||
if self.is_decoder and encoder_hidden_states is not None:
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
|
encoder_decoder_position_bias = layer_outputs[5 if self.output_attentions else 3]
|
||||||
# append next layer key value states
|
# append next layer key value states
|
||||||
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
||||||
|
|
||||||
|
|||||||
@@ -682,7 +682,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
|
||||||
position_bias = layer_outputs[3 if self.output_attentions else 2]
|
position_bias = layer_outputs[3 if self.output_attentions else 2]
|
||||||
if self.is_decoder and encoder_hidden_states is not None:
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
|
encoder_decoder_position_bias = layer_outputs[5 if self.output_attentions else 3]
|
||||||
# append next layer key value states
|
# append next layer key value states
|
||||||
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user