Add FlaxCLIP (#11883)
* add flax CLIP * default input_shape * add tests * fix test * fix name * fix docs * fix shapes * attend at least 1 token * flax conv to torch conv * return floats * fix equivalence tests * fix import * return attention_weights and update tests * fix dosctrings * address patricks comments * input_shape arg * add tests for get_image_features and get_text_features methods * fix tests
This commit is contained in:
@@ -60,6 +60,22 @@ def ids_tensor(shape, vocab_size, rng=None):
|
||||
return output
|
||||
|
||||
|
||||
def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
"""Creates a random float32 tensor"""
|
||||
if rng is None:
|
||||
rng = random.Random()
|
||||
|
||||
total_dims = 1
|
||||
for dim in shape:
|
||||
total_dims *= dim
|
||||
|
||||
values = []
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.random() * scale)
|
||||
|
||||
return np.array(values, dtype=jnp.float32).reshape(shape)
|
||||
|
||||
|
||||
def random_attention_mask(shape, rng=None):
|
||||
attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
|
||||
# make sure that at least one token is attended to for each batch
|
||||
|
||||
Reference in New Issue
Block a user