From f5cd27694a0c7d0036954c8350f774a5c1181a57 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 12 Aug 2021 18:35:01 +0530 Subject: [PATCH] [FlaxCLIP] allow passing params to image and text feature methods (#13099) * allow passing params to image and text feature method * ifx for hybrid clip as well --- .../hybrid_clip/modeling_hybrid_clip.py | 9 ++++++--- .../models/clip/modeling_flax_clip.py | 16 ++++++++++++---- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py index df7693ef0b..1348cf99af 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py @@ -208,6 +208,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): attention_mask=None, position_ids=None, token_type_ids=None, + params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False, ): @@ -254,7 +255,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): return text_features return self.module.apply( - {"params": self.params}, + {"params": params or self.params}, jnp.array(input_ids, dtype="i4"), jnp.array(attention_mask, dtype="i4"), jnp.array(position_ids, dtype="i4"), @@ -264,7 +265,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): rngs=rngs, ) - def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False): + def get_image_features( + self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False + ): r""" Args: pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`): @@ -289,7 +292,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): return image_features return self.module.apply( - {"params": self.params}, + {"params": params or self.params}, jnp.array(pixel_values, dtype=jnp.float32), not train, method=_get_features, diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py index 2285bbf1f9..ff5efc050b 100644 --- a/src/transformers/models/clip/modeling_flax_clip.py +++ b/src/transformers/models/clip/modeling_flax_clip.py @@ -785,7 +785,13 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): ) def get_text_features( - self, input_ids, attention_mask=None, position_ids=None, dropout_rng: jax.random.PRNGKey = None, train=False + self, + input_ids, + attention_mask=None, + position_ids=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train=False, ): r""" Args: @@ -836,7 +842,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): return text_features return self.module.apply( - {"params": self.params}, + {"params": params or self.params}, jnp.array(input_ids, dtype="i4"), jnp.array(attention_mask, dtype="i4"), jnp.array(position_ids, dtype="i4"), @@ -845,7 +851,9 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): rngs=rngs, ) - def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False): + def get_image_features( + self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train=False + ): r""" Args: pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`): @@ -887,7 +895,7 @@ class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel): return image_features return self.module.apply( - {"params": self.params}, + {"params": params or self.params}, jnp.array(pixel_values, dtype=jnp.float32), not train, method=_get_features,