Replace default split function with jnp.split() in flax models (#37001)

Replace split with jnp's split function for flax models (#36854)
This commit is contained in:
Prem Kumar M
2025-03-27 20:29:57 +05:30
committed by GitHub
parent 41a0e58e5b
commit 4cc65e990f
5 changed files with 6 additions and 6 deletions

View File

@@ -1087,7 +1087,7 @@ class FlaxAlbertForQuestionAnsweringModule(nn.Module):
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states) logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)

View File

@@ -2407,7 +2407,7 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module):
# removing question tokens from the competition # removing question tokens from the competition
logits = logits - logits_mask * 1e6 logits = logits - logits_mask * 1e6
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)

View File

@@ -861,7 +861,7 @@ class FlaxDistilBertForQuestionAnsweringModule(nn.Module):
hidden_states = self.dropout(hidden_states, deterministic=deterministic) hidden_states = self.dropout(hidden_states, deterministic=deterministic)
logits = self.qa_outputs(hidden_states) logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)

View File

@@ -1364,7 +1364,7 @@ class FlaxElectraForQuestionAnsweringModule(nn.Module):
) )
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states) logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)

View File

@@ -262,7 +262,7 @@ class FlaxRoFormerSelfAttention(nn.Module):
@staticmethod @staticmethod
def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None): def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
sin, cos = sinusoidal_pos.split(2, axis=-1) sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape) sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape)
cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape) cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape)
@@ -1046,7 +1046,7 @@ class FlaxRoFormerForQuestionAnsweringModule(nn.Module):
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states) logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)