Remove jnp.DeviceArray since it is deprecated. (#24875)
* Remove jnp.DeviceArray since it is deprecated. * Replace all instances of jnp.DeviceArray with jax.Array * Update src/transformers/models/bert/modeling_flax_bert.py --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -1469,7 +1469,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
|
||||
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
||||
module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
@@ -2969,8 +2969,8 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{coo
|
||||
self,
|
||||
decoder_input_ids,
|
||||
max_length,
|
||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
||||
attention_mask: Optional[jax.Array] = None,
|
||||
decoder_attention_mask: Optional[jax.Array] = None,
|
||||
encoder_outputs=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user