From 4cc65e990fa03158ac5e1d4d2d034cadae0c6213 Mon Sep 17 00:00:00 2001 From: Prem Kumar M <44919870+premmurugan229@users.noreply.github.com> Date: Thu, 27 Mar 2025 20:29:57 +0530 Subject: [PATCH] Replace default split function with jnp.split() in flax models (#37001) Replace split with jnp's split function for flax models (#36854) --- src/transformers/models/albert/modeling_flax_albert.py | 2 +- src/transformers/models/big_bird/modeling_flax_big_bird.py | 2 +- .../models/distilbert/modeling_flax_distilbert.py | 2 +- src/transformers/models/electra/modeling_flax_electra.py | 2 +- src/transformers/models/roformer/modeling_flax_roformer.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/albert/modeling_flax_albert.py b/src/transformers/models/albert/modeling_flax_albert.py index b5b49219ae..df2ebddc7e 100644 --- a/src/transformers/models/albert/modeling_flax_albert.py +++ b/src/transformers/models/albert/modeling_flax_albert.py @@ -1087,7 +1087,7 @@ class FlaxAlbertForQuestionAnsweringModule(nn.Module): hidden_states = outputs[0] logits = self.qa_outputs(hidden_states) - start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index 6cb31b2eae..7f43a4c5ab 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -2407,7 +2407,7 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module): # removing question tokens from the competition logits = logits - logits_mask * 1e6 - start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) diff --git a/src/transformers/models/distilbert/modeling_flax_distilbert.py b/src/transformers/models/distilbert/modeling_flax_distilbert.py index 683cfc67ee..1f2b6ac96a 100644 --- a/src/transformers/models/distilbert/modeling_flax_distilbert.py +++ b/src/transformers/models/distilbert/modeling_flax_distilbert.py @@ -861,7 +861,7 @@ class FlaxDistilBertForQuestionAnsweringModule(nn.Module): hidden_states = self.dropout(hidden_states, deterministic=deterministic) logits = self.qa_outputs(hidden_states) - start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 646f66a878..77a445e6cc 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -1364,7 +1364,7 @@ class FlaxElectraForQuestionAnsweringModule(nn.Module): ) hidden_states = outputs[0] logits = self.qa_outputs(hidden_states) - start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1) diff --git a/src/transformers/models/roformer/modeling_flax_roformer.py b/src/transformers/models/roformer/modeling_flax_roformer.py index 0c7cdab581..f47146f16b 100644 --- a/src/transformers/models/roformer/modeling_flax_roformer.py +++ b/src/transformers/models/roformer/modeling_flax_roformer.py @@ -262,7 +262,7 @@ class FlaxRoFormerSelfAttention(nn.Module): @staticmethod def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None): - sin, cos = sinusoidal_pos.split(2, axis=-1) + sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1) sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape) cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape) @@ -1046,7 +1046,7 @@ class FlaxRoFormerForQuestionAnsweringModule(nn.Module): hidden_states = outputs[0] logits = self.qa_outputs(hidden_states) - start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) start_logits = start_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)