Replace default split function with jnp.split() in flax models (#37001)
Replace split with jnp's split function for flax models (#36854)
This commit is contained in:
@@ -1087,7 +1087,7 @@ class FlaxAlbertForQuestionAnsweringModule(nn.Module):
|
|||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
logits = self.qa_outputs(hidden_states)
|
logits = self.qa_outputs(hidden_states)
|
||||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
|
||||||
start_logits = start_logits.squeeze(-1)
|
start_logits = start_logits.squeeze(-1)
|
||||||
end_logits = end_logits.squeeze(-1)
|
end_logits = end_logits.squeeze(-1)
|
||||||
|
|
||||||
|
|||||||
@@ -2407,7 +2407,7 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module):
|
|||||||
# removing question tokens from the competition
|
# removing question tokens from the competition
|
||||||
logits = logits - logits_mask * 1e6
|
logits = logits - logits_mask * 1e6
|
||||||
|
|
||||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
|
||||||
start_logits = start_logits.squeeze(-1)
|
start_logits = start_logits.squeeze(-1)
|
||||||
end_logits = end_logits.squeeze(-1)
|
end_logits = end_logits.squeeze(-1)
|
||||||
|
|
||||||
|
|||||||
@@ -861,7 +861,7 @@ class FlaxDistilBertForQuestionAnsweringModule(nn.Module):
|
|||||||
|
|
||||||
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
||||||
logits = self.qa_outputs(hidden_states)
|
logits = self.qa_outputs(hidden_states)
|
||||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
|
||||||
start_logits = start_logits.squeeze(-1)
|
start_logits = start_logits.squeeze(-1)
|
||||||
end_logits = end_logits.squeeze(-1)
|
end_logits = end_logits.squeeze(-1)
|
||||||
|
|
||||||
|
|||||||
@@ -1364,7 +1364,7 @@ class FlaxElectraForQuestionAnsweringModule(nn.Module):
|
|||||||
)
|
)
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
logits = self.qa_outputs(hidden_states)
|
logits = self.qa_outputs(hidden_states)
|
||||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
|
||||||
start_logits = start_logits.squeeze(-1)
|
start_logits = start_logits.squeeze(-1)
|
||||||
end_logits = end_logits.squeeze(-1)
|
end_logits = end_logits.squeeze(-1)
|
||||||
|
|
||||||
|
|||||||
@@ -262,7 +262,7 @@ class FlaxRoFormerSelfAttention(nn.Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
|
def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
|
||||||
sin, cos = sinusoidal_pos.split(2, axis=-1)
|
sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
|
||||||
sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape)
|
sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape)
|
||||||
cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape)
|
cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape)
|
||||||
|
|
||||||
@@ -1046,7 +1046,7 @@ class FlaxRoFormerForQuestionAnsweringModule(nn.Module):
|
|||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
logits = self.qa_outputs(hidden_states)
|
logits = self.qa_outputs(hidden_states)
|
||||||
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
|
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
|
||||||
start_logits = start_logits.squeeze(-1)
|
start_logits = start_logits.squeeze(-1)
|
||||||
end_logits = end_logits.squeeze(-1)
|
end_logits = end_logits.squeeze(-1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user