From e6f221c8d4829c9a3bca699c18a32043ab21f7a0 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 9 Sep 2022 14:18:56 +0100 Subject: [PATCH] [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361) * [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* * fix double tree_util --- .../run_image_captioning_flax.py | 4 ++-- .../flax/language-modeling/run_clm_flax.py | 4 ++-- .../flax/language-modeling/run_mlm_flax.py | 10 +++++----- .../flax/language-modeling/run_t5_mlm_flax.py | 6 +++--- .../summarization/run_summarization_flax.py | 6 +++--- .../flax/vision/run_image_classification.py | 4 ++-- .../jax-projects/big_bird/bigbird_flax.py | 2 +- .../dataset-streaming/run_mlm_flax_stream.py | 6 +++--- .../hybrid_clip/run_hybrid_clip.py | 2 +- .../jax-projects/model_parallel/run_clm_mp.py | 8 ++++---- .../wav2vec2/run_wav2vec2_pretrain_flax.py | 4 ++-- .../performer/run_mlm_performer.py | 4 ++-- src/transformers/generation_flax_utils.py | 6 +++--- .../modeling_flax_pytorch_utils.py | 4 ++-- src/transformers/modeling_flax_utils.py | 6 +++--- .../convert_owlvit_original_flax_to_hf.py | 4 ++-- tests/test_modeling_flax_common.py | 18 +++++++++--------- 17 files changed, 49 insertions(+), 49 deletions(-) diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py index 348a719857..5b3fd187f0 100644 --- a/examples/flax/image-captioning/run_image_captioning_flax.py +++ b/examples/flax/image-captioning/run_image_captioning_flax.py @@ -1011,7 +1011,7 @@ def main(): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) model.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir), params=params) tokenizer.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir)) if training_args.push_to_hub: @@ -1064,7 +1064,7 @@ def main(): if metrics: # normalize metrics metrics = get_metrics(metrics) - metrics = jax.tree_map(jnp.mean, metrics) + metrics = jax.tree_util.tree_map(jnp.mean, metrics) # compute ROUGE metrics generations = [] diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 1a0428fdd6..7e0d1010c1 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -781,7 +781,7 @@ def main(): # normalize eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) try: eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) @@ -824,7 +824,7 @@ def main(): # normalize eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics) + eval_metrics = jax.tree_util.tree_map(lambda x: jnp.mean(x).item(), eval_metrics) try: eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 408e09fc11..5e1519bbd5 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -827,9 +827,9 @@ def main(): # normalize eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.sum, eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics) eval_normalizer = eval_metrics.pop("normalizer") - eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) + eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics) # Update progress bar epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" @@ -841,7 +841,7 @@ def main(): if cur_step % training_args.save_steps == 0 and cur_step > 0: # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: @@ -867,9 +867,9 @@ def main(): # normalize eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics) + eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics) eval_normalizer = eval_metrics.pop("normalizer") - eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) + eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics) try: perplexity = math.exp(eval_metrics["loss"]) diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 0030fc8da6..c9d748de3d 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -940,7 +940,7 @@ def main(): # get eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) # Update progress bar epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})") @@ -952,7 +952,7 @@ def main(): if cur_step % training_args.save_steps == 0 and cur_step > 0: # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: @@ -978,7 +978,7 @@ def main(): # get eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics) + eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics) if jax.process_index() == 0: eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index 2813c88a3b..ed151b8bbe 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -902,7 +902,7 @@ def main(): # normalize eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) # compute ROUGE metrics rouge_desc = "" @@ -923,7 +923,7 @@ def main(): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: @@ -957,7 +957,7 @@ def main(): # normalize prediction metrics pred_metrics = get_metrics(pred_metrics) - pred_metrics = jax.tree_map(jnp.mean, pred_metrics) + pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics) # compute ROUGE metrics rouge_desc = "" diff --git a/examples/flax/vision/run_image_classification.py b/examples/flax/vision/run_image_classification.py index 3de3c977ab..22065438d2 100644 --- a/examples/flax/vision/run_image_classification.py +++ b/examples/flax/vision/run_image_classification.py @@ -542,7 +542,7 @@ def main(): # normalize eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) # Print metrics and update progress bar eval_step_progress_bar.close() @@ -560,7 +560,7 @@ def main(): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) model.save_pretrained(training_args.output_dir, params=params) if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False) diff --git a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py index d272125472..b9ff9da281 100644 --- a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py +++ b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py @@ -104,7 +104,7 @@ class DataCollator: def __call__(self, batch): batch = self.collate_fn(batch) - batch = jax.tree_map(shard, batch) + batch = jax.tree_util.tree_map(shard, batch) return batch def collate_fn(self, features): diff --git a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py index fadcec09cb..e6bbdbee8c 100755 --- a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py +++ b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py @@ -608,9 +608,9 @@ if __name__ == "__main__": # normalize eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.sum, eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics) eval_normalizer = eval_metrics.pop("normalizer") - eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) + eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics) # Update progress bar steps.desc = ( @@ -624,7 +624,7 @@ if __name__ == "__main__": # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) model.save_pretrained( training_args.output_dir, params=params, diff --git a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py index 6ee974666a..1be46f6af9 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py @@ -551,7 +551,7 @@ def main(): # normalize eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) # Print metrics and update progress bar eval_step_progress_bar.close() diff --git a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py index 518ef9f7b2..16eb1007b4 100644 --- a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py +++ b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py @@ -481,7 +481,7 @@ def main(): param_spec = set_partitions(unfreeze(model.params)) # Get the PyTree for opt_state, we don't actually initialize the opt_state yet. - params_shapes = jax.tree_map(lambda x: x.shape, model.params) + params_shapes = jax.tree_util.tree_map(lambda x: x.shape, model.params) state_shapes = jax.eval_shape(get_initial_state, params_shapes) # get PartitionSpec for opt_state, this is very specific to adamw @@ -492,7 +492,7 @@ def main(): return param_spec return None - opt_state_spec, param_spec = jax.tree_map( + opt_state_spec, param_spec = jax.tree_util.tree_map( get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)) ) @@ -506,7 +506,7 @@ def main(): # hack: move the inital params to CPU to free up device memory # TODO: allow loading weights on CPU in pre-trained model - model.params = jax.tree_map(lambda x: np.asarray(x), model.params) + model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params) # mesh defination mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count()) @@ -636,7 +636,7 @@ def main(): # normalize eval metrics eval_metrics = stack_forest(eval_metrics) - eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) try: eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) diff --git a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py index 457c58d44f..71bf60d2c6 100755 --- a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py +++ b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py @@ -591,7 +591,7 @@ def main(): # get eval metrics eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics) # Update progress bar epochs.write( @@ -606,7 +606,7 @@ def main(): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)) model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub) diff --git a/examples/research_projects/performer/run_mlm_performer.py b/examples/research_projects/performer/run_mlm_performer.py index 8e8fe91765..35de233f72 100644 --- a/examples/research_projects/performer/run_mlm_performer.py +++ b/examples/research_projects/performer/run_mlm_performer.py @@ -674,9 +674,9 @@ if __name__ == "__main__": eval_metrics.append(metrics) eval_metrics_np = get_metrics(eval_metrics) - eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np) + eval_metrics_np = jax.tree_util.tree_map(jnp.sum, eval_metrics_np) eval_normalizer = eval_metrics_np.pop("normalizer") - eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np) + eval_summary = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics_np) # Update progress bar epochs.desc = ( diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py index 1c052aae7b..353df6fdbb 100644 --- a/src/transformers/generation_flax_utils.py +++ b/src/transformers/generation_flax_utils.py @@ -699,7 +699,7 @@ class FlaxGenerationMixin: else: return tensor[batch_indices, beam_indices] - return jax.tree_map(gather_fn, nested) + return jax.tree_util.tree_map(gather_fn, nested) # init values max_length = max_length if max_length is not None else self.config.max_length @@ -788,7 +788,7 @@ class FlaxGenerationMixin: model_outputs = model(input_token, params=params, **state.model_kwargs) logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams) - cache = jax.tree_map( + cache = jax.tree_util.tree_map( lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values ) @@ -874,7 +874,7 @@ class FlaxGenerationMixin: # With these, gather the top k beam-associated caches. next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams) next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams) - model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache) + model_outputs["past_key_values"] = jax.tree_util.tree_map(lambda x: flatten_beam_dim(x), next_cache) next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs) return BeamSearchState( diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 76eaa53f89..47da8c2871 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -253,7 +253,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): raise # check if we have bf16 weights - is_type_bf16 = flatten_dict(jax.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() + is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values() if any(is_type_bf16): # convert all weights to fp32 if the are bf16 since torch.from_numpy can-not handle bf16 # and bf16 is not fully supported in PT yet. @@ -261,7 +261,7 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` " "before loading those in PyTorch model." ) - flax_state = jax.tree_map( + flax_state = jax.tree_util.tree_map( lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state ) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 00bb5480ff..b19f3db77e 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -303,10 +303,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): return param if mask is None: - return jax.tree_map(conditional_cast, params) + return jax.tree_util.tree_map(conditional_cast, params) flat_params = flatten_dict(params) - flat_mask, _ = jax.tree_flatten(mask) + flat_mask, _ = jax.tree_util.tree_flatten(mask) for masked, key in zip(flat_mask, flat_params.keys()): if masked: @@ -900,7 +900,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ) # dictionary of key: dtypes for the model params - param_dtypes = jax.tree_map(lambda x: x.dtype, state) + param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state) # extract keys of parameters not in jnp.float32 fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16] bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16] diff --git a/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py b/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py index dde57c168a..09942fa392 100644 --- a/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py +++ b/src/transformers/models/owlvit/convert_owlvit_original_flax_to_hf.py @@ -90,7 +90,7 @@ def flatten_nested_dict(params, parent_key="", sep="/"): def to_f32(params): - return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params) + return jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, params) def copy_attn_layer(hf_attn_layer, pt_attn_layer): @@ -398,7 +398,7 @@ if __name__ == "__main__": # Load from checkpoint and convert params to float-32 variables = checkpoints.restore_checkpoint(args.owlvit_checkpoint, target=None)["optimizer"]["target"] - flax_params = jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables) + flax_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, variables) del variables # Convert CLIP backbone diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 837f874889..37171e2138 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -776,7 +776,7 @@ class FlaxModelTesterMixin: for model_class in self.all_model_classes: # check if all params are still in float32 when dtype of computation is half-precision model = model_class(config, dtype=jnp.float16) - types = jax.tree_map(lambda x: x.dtype, model.params) + types = jax.tree_util.tree_map(lambda x: x.dtype, model.params) types = flatten_dict(types) for name, type_ in types.items(): @@ -790,7 +790,7 @@ class FlaxModelTesterMixin: # cast all params to bf16 params = model.to_bf16(model.params) - types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params)) # test if all params are in bf16 for name, type_ in types.items(): self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.") @@ -802,7 +802,7 @@ class FlaxModelTesterMixin: mask = unflatten_dict(mask) params = model.to_bf16(model.params, mask) - types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params)) # test if all params are in bf16 except key for name, type_ in types.items(): if name == key: @@ -818,7 +818,7 @@ class FlaxModelTesterMixin: # cast all params to fp16 params = model.to_fp16(model.params) - types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params)) # test if all params are in fp16 for name, type_ in types.items(): self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.") @@ -830,7 +830,7 @@ class FlaxModelTesterMixin: mask = unflatten_dict(mask) params = model.to_fp16(model.params, mask) - types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params)) # test if all params are in fp16 except key for name, type_ in types.items(): if name == key: @@ -849,7 +849,7 @@ class FlaxModelTesterMixin: params = model.to_fp32(params) # test if all params are in fp32 - types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params)) for name, type_ in types.items(): self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.") @@ -864,7 +864,7 @@ class FlaxModelTesterMixin: params = model.to_fp32(params, mask) # test if all params are in fp32 except key - types = flatten_dict(jax.tree_map(lambda x: x.dtype, params)) + types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params)) for name, type_ in types.items(): if name == key: self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.") @@ -884,7 +884,7 @@ class FlaxModelTesterMixin: # load the weights again and check if they are still in fp16 model = model_class.from_pretrained(tmpdirname) - types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params)) + types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params)) for name, type_ in types.items(): self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.") @@ -901,7 +901,7 @@ class FlaxModelTesterMixin: # load the weights again and check if they are still in fp16 model = model_class.from_pretrained(tmpdirname) - types = flatten_dict(jax.tree_map(lambda x: x.dtype, model.params)) + types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params)) for name, type_ in types.items(): self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")