Fix Flax params dtype (#13098)
* fix inits
* fix embed dtype
* fix embed dtype
* add test to check default dtype
* quality
* add type conversion methods for flax models
* more robust casting
* cast sinusoidal positions
* update pegasus
* update albert
* update test
* make sure dtype is passed to every module
* style
* fix electra dense
* fix t5
* quality
* add more tests
* better name
* use the dtype for lm head computation
* fix albert
* style
* fix albert embed dtype
* more tests
* fix vision enc-dec
* cleanup
* fix embed dtype pegasus
* fix default param test
* doc
* update template
* fix final_logits_bias dtype
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* fix doc
* fix doc
* add detailed docstring for dtype parameter
* remove un-necessary import

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -50,13 +50,13 @@ class FlaxHybridCLIPModule(nn.Module):
         self.visual_projection = nn.Dense(
             self.projection_dim,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
+            kernel_init=jax.nn.initializers.normal(0.02),
             use_bias=False,
         )
         self.text_projection = nn.Dense(
             self.projection_dim,
             dtype=self.dtype,
-            kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
+            kernel_init=jax.nn.initializers.normal(0.02),
             use_bias=False,
         )
         self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
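
The effect of dropping `dtype=self.dtype` from the initializer can be illustrated outside the model. The snippet below is a minimal sketch, not code from this commit: the shapes, the choice of `bfloat16`, and the names `old_kernel`, `new_kernel`, and `params` are assumptions made for illustration. It shows that `jax.nn.initializers.normal(0.02)` produces float32 values by default, while the computation can still run in a lower precision by casting at call time. The final `tree_map` cast only hints at how a parameter tree could be converted explicitly; it is not the conversion method this commit adds.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
shape = (8, 4)  # hypothetical (input_dim, projection_dim)

# Before the change: the initializer is tied to the computation dtype,
# so the kernel itself is created in that dtype.
old_kernel = jax.nn.initializers.normal(0.02, dtype=jnp.bfloat16)(key, shape)

# After the change: the initializer falls back to its float32 default.
new_kernel = jax.nn.initializers.normal(0.02)(key, shape)

print(old_kernel.dtype, new_kernel.dtype)

# The computation can still use a lower precision by casting when the
# kernel is applied, leaving the stored values in float32.
x = jnp.ones((2, 8), dtype=jnp.bfloat16)
y = jnp.dot(x, new_kernel.astype(jnp.bfloat16))

# Illustrative explicit cast of a whole parameter tree (assumed layout).
params = {"visual_projection": {"kernel": new_kernel}}
params_bf16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params)
```

Keeping stored parameters in float32 and treating `dtype` as the computation dtype means a lower-precision copy of the weights is something the caller requests explicitly rather than a side effect of initialization.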