Replace default split function with jnp.split() in flax models (#37001)
Replace split with jnp's split function for flax models (#36854)
This commit is contained in:
@@ -1087,7 +1087,7 @@ class FlaxAlbertForQuestionAnsweringModule(nn.Module):
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(hidden_states)
|
||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
||||
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
|
||||
@@ -2407,7 +2407,7 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module):
|
||||
# removing question tokens from the competition
|
||||
logits = logits - logits_mask * 1e6
|
||||
|
||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
||||
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
|
||||
@@ -861,7 +861,7 @@ class FlaxDistilBertForQuestionAnsweringModule(nn.Module):
|
||||
|
||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||
logits = self.qa_outputs(hidden_states)
|
||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
||||
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
|
||||
@@ -1364,7 +1364,7 @@ class FlaxElectraForQuestionAnsweringModule(nn.Module):
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
logits = self.qa_outputs(hidden_states)
|
||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
||||
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
|
||||
@@ -262,7 +262,7 @@ class FlaxRoFormerSelfAttention(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
|
||||
sin, cos = sinusoidal_pos.split(2, axis=-1)
|
||||
sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
|
||||
sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape)
|
||||
cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape)
|
||||
|
||||
@@ -1046,7 +1046,7 @@ class FlaxRoFormerForQuestionAnsweringModule(nn.Module):
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = self.qa_outputs(hidden_states)
|
||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
||||
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user