remove jax.ops.index (#16220)
This commit is contained in:
@@ -35,7 +35,7 @@ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
|
|||||||
|
|
||||||
emb = jnp.zeros((50264, model.config.hidden_size))
|
emb = jnp.zeros((50264, model.config.hidden_size))
|
||||||
# update the first 50257 weights using pre-trained weights
|
# update the first 50257 weights using pre-trained weights
|
||||||
emb = emb.at[jax.ops.index[:50257, :]].set(model.params["transformer"]["wte"]["embedding"])
|
emb = emb.at[:50257, :].set(model.params["transformer"]["wte"]["embedding"])
|
||||||
params = model.params
|
params = model.params
|
||||||
params["transformer"]["wte"]["embedding"] = emb
|
params["transformer"]["wte"]["embedding"] = emb
|
||||||
|
|
||||||
|
|||||||
@@ -143,10 +143,10 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
|
|||||||
|
|
||||||
# include the token that is higher than top_p as well
|
# include the token that is higher than top_p as well
|
||||||
score_mask = jnp.roll(score_mask, 1)
|
score_mask = jnp.roll(score_mask, 1)
|
||||||
score_mask |= score_mask.at[jax.ops.index[:, 0]].set(True)
|
score_mask |= score_mask.at[:, 0].set(True)
|
||||||
|
|
||||||
# min tokens to keep
|
# min tokens to keep
|
||||||
score_mask = score_mask.at[jax.ops.index[:, : self.min_tokens_to_keep]].set(True)
|
score_mask = score_mask.at[:, : self.min_tokens_to_keep].set(True)
|
||||||
|
|
||||||
topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
|
topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
|
||||||
next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
|
next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
|
||||||
@@ -207,7 +207,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
|
|||||||
|
|
||||||
apply_penalty = 1 - jnp.bool_(cur_len - 1)
|
apply_penalty = 1 - jnp.bool_(cur_len - 1)
|
||||||
|
|
||||||
scores = jnp.where(apply_penalty, new_scores.at[jax.ops.index[:, self.bos_token_id]].set(0), scores)
|
scores = jnp.where(apply_penalty, new_scores.at[:, self.bos_token_id].set(0), scores)
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@@ -232,7 +232,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
|
|||||||
|
|
||||||
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
|
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
|
||||||
|
|
||||||
scores = jnp.where(apply_penalty, new_scores.at[jax.ops.index[:, self.eos_token_id]].set(0), scores)
|
scores = jnp.where(apply_penalty, new_scores.at[:, self.eos_token_id].set(0), scores)
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@@ -263,6 +263,6 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
|
|||||||
# create boolean flag to decide if min length penalty should be applied
|
# create boolean flag to decide if min length penalty should be applied
|
||||||
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
|
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
|
||||||
|
|
||||||
scores = jnp.where(apply_penalty, scores.at[jax.ops.index[:, self.eos_token_id]].set(-float("inf")), scores)
|
scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|||||||
@@ -2124,7 +2124,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
|
|||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = (~logits_mask).astype("i4")
|
token_type_ids = (~logits_mask).astype("i4")
|
||||||
logits_mask = jnp.expand_dims(logits_mask, axis=2)
|
logits_mask = jnp.expand_dims(logits_mask, axis=2)
|
||||||
logits_mask = logits_mask.at[jax.ops.index[:, 0]].set(False)
|
logits_mask = logits_mask.at[:, 0].set(False)
|
||||||
|
|
||||||
# init input tensors if not passed
|
# init input tensors if not passed
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
|
|||||||
@@ -1422,7 +1422,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
|
|||||||
|
|
||||||
def _adapt_logits_for_beam_search(self, logits):
|
def _adapt_logits_for_beam_search(self, logits):
|
||||||
"""This function enforces the padding token never to be generated."""
|
"""This function enforces the padding token never to be generated."""
|
||||||
logits = logits.at[jax.ops.index[:, :, self.config.pad_token_id]].set(float("-inf"))
|
logits = logits.at[:, :, self.config.pad_token_id].set(float("-inf"))
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
|
|||||||
@@ -964,8 +964,7 @@ class FlaxWav2Vec2Module(nn.Module):
|
|||||||
|
|
||||||
# these two operations make sure that all values
|
# these two operations make sure that all values
|
||||||
# before the output lengths indices are attended to
|
# before the output lengths indices are attended to
|
||||||
idx = jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1]
|
attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
|
||||||
attention_mask = attention_mask.at[idx].set(1)
|
|
||||||
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
|
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
|
||||||
|
|
||||||
hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
|
hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
|
||||||
|
|||||||
Reference in New Issue
Block a user