[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)
* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* * fix double tree_util
This commit is contained in:
@@ -776,7 +776,7 @@ class FlaxModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
# check if all params are still in float32 when dtype of computation is half-precision
|
||||
model = model_class(config, dtype=jnp.float16)
|
||||
types = jax.tree_map(lambda x: x.dtype, model.params)
|
||||
types = jax.tree_util.tree_map(lambda x: x.dtype, model.params)
|
||||
types = flatten_dict(types)
|
||||
|
||||
for name, type_ in types.items():
|
||||
@@ -790,7 +790,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
# cast all params to bf16
|
||||
params = model.to_bf16(model.params)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
|
||||
# test if all params are in bf16
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
||||
@@ -802,7 +802,7 @@ class FlaxModelTesterMixin:
|
||||
mask = unflatten_dict(mask)
|
||||
|
||||
params = model.to_bf16(model.params, mask)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
|
||||
# test if all params are in bf16 except key
|
||||
for name, type_ in types.items():
|
||||
if name == key:
|
||||
@@ -818,7 +818,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
# cast all params to fp16
|
||||
params = model.to_fp16(model.params)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
|
||||
# test if all params are in fp16
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
|
||||
@@ -830,7 +830,7 @@ class FlaxModelTesterMixin:
|
||||
mask = unflatten_dict(mask)
|
||||
|
||||
params = model.to_fp16(model.params, mask)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
|
||||
# test if all params are in fp16 except key
|
||||
for name, type_ in types.items():
|
||||
if name == key:
|
||||
@@ -849,7 +849,7 @@ class FlaxModelTesterMixin:
|
||||
params = model.to_fp32(params)
|
||||
|
||||
# test if all params are in fp32
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
|
||||
|
||||
@@ -864,7 +864,7 @@ class FlaxModelTesterMixin:
|
||||
params = model.to_fp32(params, mask)
|
||||
|
||||
# test if all params are in fp32 except key
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
|
||||
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
|
||||
for name, type_ in types.items():
|
||||
if name == key:
|
||||
self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
|
||||
@@ -884,7 +884,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
# load the weights again and check if they are still in fp16
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
|
||||
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params))
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
|
||||
|
||||
@@ -901,7 +901,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
# load the weights again and check if they are still in fp16
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
|
||||
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params))
|
||||
for name, type_ in types.items():
|
||||
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user