[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)

* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_*

* fix double tree_util
This commit is contained in:
Sanchit Gandhi
2022-09-09 14:18:56 +01:00
committed by GitHub
parent 22f7218560
commit e6f221c8d4
17 changed files with 49 additions and 49 deletions

View File

@@ -776,7 +776,7 @@ class FlaxModelTesterMixin:
for model_class in self.all_model_classes:
# check if all params are still in float32 when dtype of computation is half-precision
model = model_class(config, dtype=jnp.float16)
types = jax.tree_map(lambda x: x.dtype, model.params)
types = jax.tree_util.tree_map(lambda x: x.dtype, model.params)
types = flatten_dict(types)
for name, type_ in types.items():
@@ -790,7 +790,7 @@ class FlaxModelTesterMixin:
# cast all params to bf16
params = model.to_bf16(model.params)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
# test if all params are in bf16
for name, type_ in types.items():
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
@@ -802,7 +802,7 @@ class FlaxModelTesterMixin:
mask = unflatten_dict(mask)
params = model.to_bf16(model.params, mask)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
# test if all params are in bf16 except key
for name, type_ in types.items():
if name == key:
@@ -818,7 +818,7 @@ class FlaxModelTesterMixin:
# cast all params to fp16
params = model.to_fp16(model.params)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
# test if all params are in fp16
for name, type_ in types.items():
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
@@ -830,7 +830,7 @@ class FlaxModelTesterMixin:
mask = unflatten_dict(mask)
params = model.to_fp16(model.params, mask)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
# test if all params are in fp16 except key
for name, type_ in types.items():
if name == key:
@@ -849,7 +849,7 @@ class FlaxModelTesterMixin:
params = model.to_fp32(params)
# test if all params are in fp32
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
for name, type_ in types.items():
self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
@@ -864,7 +864,7 @@ class FlaxModelTesterMixin:
params = model.to_fp32(params, mask)
# test if all params are in fp32 except key
types = flatten_dict(jax.tree_map(lambda x: x.dtype, params))
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
for name, type_ in types.items():
if name == key:
self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
@@ -884,7 +884,7 @@ class FlaxModelTesterMixin:
# load the weights again and check if they are still in fp16
model = model_class.from_pretrained(tmpdirname)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params))
for name, type_ in types.items():
self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
@@ -901,7 +901,7 @@ class FlaxModelTesterMixin:
# load the weights again and check if they are still in fp16
model = model_class.from_pretrained(tmpdirname)
types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params))
types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params))
for name, type_ in types.items():
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")