Cannot index None (#6984)
This commit is contained in:
@@ -464,6 +464,8 @@ class BertEncoder(nn.Module):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False):
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
@@ -476,7 +478,7 @@ class BertEncoder(nn.Module):
|
|||||||
create_custom_forward(layer_module),
|
create_custom_forward(layer_module),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask[i],
|
layer_head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
)
|
)
|
||||||
@@ -484,7 +486,7 @@ class BertEncoder(nn.Module):
|
|||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask[i],
|
layer_head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
|
|||||||
Reference in New Issue
Block a user