[FlaxCLIP] allow passing params to image and text feature methods (#13099)
* allow passing params to image and text feature method * ifx for hybrid clip as well
This commit is contained in:
@@ -208,6 +208,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train=False,
|
||||
):
|
||||
@@ -254,7 +255,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
return text_features
|
||||
|
||||
return self.module.apply(
|
||||
{"params": self.params},
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
@@ -264,7 +265,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
|
||||
def get_image_features(
|
||||
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||
@@ -289,7 +292,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
return image_features
|
||||
|
||||
return self.module.apply(
|
||||
{"params": self.params},
|
||||
{"params": params or self.params},
|
||||
jnp.array(pixel_values, dtype=jnp.float32),
|
||||
not train,
|
||||
method=_get_features,
|
||||
|
||||
@@ -785,7 +785,13 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
||||
)
|
||||
|
||||
def get_text_features(
|
||||
self, input_ids, attention_mask=None, position_ids=None, dropout_rng: jax.random.PRNGKey = None, train=False
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
dropout_rng: jax.random.PRNGKey = None,
|
||||
train=False,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
@@ -836,7 +842,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
||||
return text_features
|
||||
|
||||
return self.module.apply(
|
||||
{"params": self.params},
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
@@ -845,7 +851,9 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
|
||||
def get_image_features(
|
||||
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||
@@ -887,7 +895,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
||||
return image_features
|
||||
|
||||
return self.module.apply(
|
||||
{"params": self.params},
|
||||
{"params": params or self.params},
|
||||
jnp.array(pixel_values, dtype=jnp.float32),
|
||||
not train,
|
||||
method=_get_features,
|
||||
|
||||
Reference in New Issue
Block a user