[FlaxCLIP] allow passing params to image and text feature methods (#13099)

* allow passing params to image and text feature method

* ifx for hybrid clip as well
This commit is contained in:
Suraj Patil
2021-08-12 18:35:01 +05:30
committed by GitHub
parent 9a498c37a2
commit f5cd27694a
2 changed files with 18 additions and 7 deletions

View File

@@ -208,6 +208,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
attention_mask=None,
position_ids=None,
token_type_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train=False,
):
@@ -254,7 +255,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
return text_features
return self.module.apply(
{"params": self.params},
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
@@ -264,7 +265,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
rngs=rngs,
)
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
def get_image_features(
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
):
r"""
Args:
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
@@ -289,7 +292,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
return image_features
return self.module.apply(
{"params": self.params},
{"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
method=_get_features,

View File

@@ -785,7 +785,13 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
)
def get_text_features(
self, input_ids, attention_mask=None, position_ids=None, dropout_rng: jax.random.PRNGKey = None, train=False
self,
input_ids,
attention_mask=None,
position_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train=False,
):
r"""
Args:
@@ -836,7 +842,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
return text_features
return self.module.apply(
{"params": self.params},
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
@@ -845,7 +851,9 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
rngs=rngs,
)
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
def get_image_features(
self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False
):
r"""
Args:
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
@@ -887,7 +895,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
return image_features
return self.module.apply(
{"params": self.params},
{"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
method=_get_features,