Fix Flax params dtype (#13098)

* fix inits

* fix embed dtype

* fix embed dtype

* add test to check default dtype

* quality

* add type conversion methods for flax models

* more robust casting

* cast sinusoidal positions

* update pegasus

* update albert

* update test

* make sure dtype is passed to every module

* style

* fix electra dense

* fix t5

* quality

* add more tests

* better name

* use the dtype for lm head computation

* fix albert

* style

* fix albert embed dtype

* more tests

* fix vision enc-dec

* cleanup

* fix embed dtype pegasus

* fix default param test

* doc

* update template

* fix final_logits_bias dtype

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix doc

* fix doc

* add detailed docstring for dtype parameter

* remove un-necessary import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil
2021-11-11 14:45:20 +05:30
committed by GitHub
parent 1c76a51615
commit e92190c0f8
23 changed files with 731 additions and 262 deletions

View File

@@ -50,13 +50,13 @@ class FlaxHybridCLIPModule(nn.Module):
self.visual_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.text_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
kernel_init=jax.nn.initializers.normal(0.02),
use_bias=False,
)
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])