Remove jnp.DeviceArray since it is deprecated. (#24875)
* Remove jnp.DeviceArray since it is deprecated. * Replace all instances of jnp.DeviceArray with jax.Array * Update src/transformers/models/bert/modeling_flax_bert.py --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -1467,8 +1467,8 @@ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -1960,7 +1960,7 @@ class FlaxBartForCausalLMModule(nn.Module):
|
|||||||
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
|
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
|
||||||
module_class = FlaxBartForCausalLMModule
|
module_class = FlaxBartForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -1677,7 +1677,7 @@ class FlaxBertForCausalLMModule(nn.Module):
|
|||||||
class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
|
class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
|
||||||
module_class = FlaxBertForCausalLMModule
|
module_class = FlaxBertForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -2599,7 +2599,7 @@ class FlaxBigBirdForCausalLMModule(nn.Module):
|
|||||||
class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel):
|
class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel):
|
||||||
module_class = FlaxBigBirdForCausalLMModule
|
module_class = FlaxBigBirdForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -1443,8 +1443,8 @@ class FlaxBlenderbotForConditionalGeneration(FlaxBlenderbotPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1441,8 +1441,8 @@ class FlaxBlenderbotSmallForConditionalGeneration(FlaxBlenderbotSmallPreTrainedM
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1565,7 +1565,7 @@ class FlaxElectraForCausalLMModule(nn.Module):
|
|||||||
class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
|
class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
|
||||||
module_class = FlaxElectraForCausalLMModule
|
module_class = FlaxElectraForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -722,8 +722,8 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -742,7 +742,7 @@ class FlaxGPT2LMHeadModule(nn.Module):
|
|||||||
class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
|
class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
|
||||||
module_class = FlaxGPT2LMHeadModule
|
module_class = FlaxGPT2LMHeadModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -654,7 +654,7 @@ class FlaxGPTNeoForCausalLMModule(nn.Module):
|
|||||||
class FlaxGPTNeoForCausalLM(FlaxGPTNeoPreTrainedModel):
|
class FlaxGPTNeoForCausalLM(FlaxGPTNeoPreTrainedModel):
|
||||||
module_class = FlaxGPTNeoForCausalLMModule
|
module_class = FlaxGPTNeoForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -683,7 +683,7 @@ class FlaxGPTJForCausalLMModule(nn.Module):
|
|||||||
class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
|
class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
|
||||||
module_class = FlaxGPTJForCausalLMModule
|
module_class = FlaxGPTJForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -2388,8 +2388,8 @@ class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1436,8 +1436,8 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1502,8 +1502,8 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -763,7 +763,7 @@ class FlaxOPTForCausalLMModule(nn.Module):
|
|||||||
class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
|
class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
|
||||||
module_class = FlaxOPTForCausalLMModule
|
module_class = FlaxOPTForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -1450,8 +1450,8 @@ class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1452,7 +1452,7 @@ class FlaxRobertaForCausalLMModule(nn.Module):
|
|||||||
class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
|
class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
|
||||||
module_class = FlaxRobertaForCausalLMModule
|
module_class = FlaxRobertaForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -1478,7 +1478,7 @@ class FlaxRobertaPreLayerNormForCausalLMModule(nn.Module):
|
|||||||
class FlaxRobertaPreLayerNormForCausalLM(FlaxRobertaPreLayerNormPreTrainedModel):
|
class FlaxRobertaPreLayerNormForCausalLM(FlaxRobertaPreLayerNormPreTrainedModel):
|
||||||
module_class = FlaxRobertaPreLayerNormForCausalLMModule
|
module_class = FlaxRobertaPreLayerNormForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -745,8 +745,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1740,8 +1740,8 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -688,7 +688,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1448,8 +1448,8 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -766,7 +766,7 @@ class FlaxXGLMForCausalLMModule(nn.Module):
|
|||||||
class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel):
|
class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel):
|
||||||
module_class = FlaxXGLMForCausalLMModule
|
module_class = FlaxXGLMForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -1469,7 +1469,7 @@ class FlaxXLMRobertaForCausalLMModule(nn.Module):
|
|||||||
class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel):
|
class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel):
|
||||||
module_class = FlaxXLMRobertaForCausalLMModule
|
module_class = FlaxXLMRobertaForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
|||||||
@@ -1469,7 +1469,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module):
|
|||||||
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
|
||||||
module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
|
module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
|
||||||
# initializing the cache
|
# initializing the cache
|
||||||
batch_size, seq_length = input_ids.shape
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
@@ -2969,8 +2969,8 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{coo
|
|||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
attention_mask: Optional[jnp.DeviceArray] = None,
|
attention_mask: Optional[jax.Array] = None,
|
||||||
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
|
decoder_attention_mask: Optional[jax.Array] = None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user