[FlaxCLIP] allow passing params to image and text feature methods (#13099)
* allow passing params to image and text feature method * ifx for hybrid clip as well
This commit is contained in:
@@ -208,6 +208,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
|||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
|
params: dict = None,
|
||||||
dropout_rng: jax.random.PRNGKey = None,
|
dropout_rng: jax.random.PRNGKey = None,
|
||||||
train=False,
|
train=False,
|
||||||
):
|
):
|
||||||
@@ -254,7 +255,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
|||||||
return text_features
|
return text_features
|
||||||
|
|
||||||
return self.module.apply(
|
return self.module.apply(
|
||||||
{"params": self.params},
|
{"params": params or self.params},
|
||||||
jnp.array(input_ids, dtype="i4"),
|
jnp.array(input_ids, dtype="i4"),
|
||||||
jnp.array(attention_mask, dtype="i4"),
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
@@ -264,7 +265,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
|||||||
rngs=rngs,
|
rngs=rngs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
|
def get_image_features(
|
||||||
|
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
|
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||||
@@ -289,7 +292,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
|||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
return self.module.apply(
|
return self.module.apply(
|
||||||
{"params": self.params},
|
{"params": params or self.params},
|
||||||
jnp.array(pixel_values, dtype=jnp.float32),
|
jnp.array(pixel_values, dtype=jnp.float32),
|
||||||
not train,
|
not train,
|
||||||
method=_get_features,
|
method=_get_features,
|
||||||
|
|||||||
@@ -785,7 +785,13 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_text_features(
|
def get_text_features(
|
||||||
self, input_ids, attention_mask=None, position_ids=None, dropout_rng: jax.random.PRNGKey = None, train=False
|
self,
|
||||||
|
input_ids,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=None,
|
||||||
|
params: dict = None,
|
||||||
|
dropout_rng: jax.random.PRNGKey = None,
|
||||||
|
train=False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -836,7 +842,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
return text_features
|
return text_features
|
||||||
|
|
||||||
return self.module.apply(
|
return self.module.apply(
|
||||||
{"params": self.params},
|
{"params": params or self.params},
|
||||||
jnp.array(input_ids, dtype="i4"),
|
jnp.array(input_ids, dtype="i4"),
|
||||||
jnp.array(attention_mask, dtype="i4"),
|
jnp.array(attention_mask, dtype="i4"),
|
||||||
jnp.array(position_ids, dtype="i4"),
|
jnp.array(position_ids, dtype="i4"),
|
||||||
@@ -845,7 +851,9 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
rngs=rngs,
|
rngs=rngs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
|
def get_image_features(
|
||||||
|
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
|
||||||
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
|
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
|
||||||
@@ -887,7 +895,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
return self.module.apply(
|
return self.module.apply(
|
||||||
{"params": self.params},
|
{"params": params or self.params},
|
||||||
jnp.array(pixel_values, dtype=jnp.float32),
|
jnp.array(pixel_values, dtype=jnp.float32),
|
||||||
not train,
|
not train,
|
||||||
method=_get_features,
|
method=_get_features,
|
||||||
|
|||||||
Reference in New Issue
Block a user