Fix Flax params dtype (#13098)
* fix inits * fix embed dtype * fix embed dtype * add test to check default dtype * quality * add type conversion methods for flax models * more robust casting * cast sinusoidal positions * update pegasus * update albert * update test * make sure dtype is passed to every module * style * fix electra dense * fix t5 * quality * add more tests * better name * use the dtype for lm head computation * fix albert * style * fix albert embed dtype * more tests * fix vision enc-dec * cleanup * fix embed dtype pegasus * fix default param test * doc * update template * fix final_logits_bias dtype * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fix doc * fix doc * add detailed docstring for dtype parameter * remove unnecessary import Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -36,7 +36,7 @@ if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
from flax.traverse_util import flatten_dict
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from transformers import (
|
||||
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
@@ -613,6 +613,141 @@ class FlaxModelTesterMixin:
|
||||
else:
|
||||
new_model_without_prefix(input_ids)
|
||||
|
||||
def test_default_params_dtype(self):
    """Check that parameters stay in float32 when a half-precision ``dtype`` is requested.

    Each class in ``self.all_model_classes`` is instantiated with
    ``dtype=jnp.float16`` and the dtype of every leaf of ``model.params`` is
    asserted to be ``jnp.float32``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        # check if all params are still in float32 when dtype of computation is half-precision
        model = model_class(config, dtype=jnp.float16)
        # `jax.tree_util.tree_map` is used instead of the deprecated top-level `jax.tree_map` alias
        types = jax.tree_util.tree_map(lambda x: x.dtype, model.params)
        types = flatten_dict(types)

        for name, type_ in types.items():
            # `assertEquals` is a deprecated alias (removed in Python 3.12), so use `assertEqual`
            self.assertEqual(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")
|
||||
|
||||
def test_to_bf16(self):
    """Check ``model.to_bf16`` with and without a mask.

    For each class in ``self.all_model_classes``:

    1. ``to_bf16(model.params)`` must return a tree whose leaves are all
       ``jnp.bfloat16``.
    2. ``to_bf16(model.params, mask)``, where ``mask`` is ``True`` for every
       path except one randomly chosen path, must leave that one leaf in
       ``jnp.float32`` and return every other leaf as ``jnp.bfloat16``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)

        # cast all params to bf16
        params = model.to_bf16(model.params)
        # `jax.tree_util.tree_map` is used instead of the deprecated top-level `jax.tree_map` alias
        types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
        # test if all params are in bf16
        for name, type_ in types.items():
            self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")

        # test masking
        flat_params = flatten_dict(params)
        key = random.choice(list(flat_params.keys()))  # choose a random param
        mask = {path: path != key for path in flat_params}  # don't cast the key
        mask = unflatten_dict(mask)

        params = model.to_bf16(model.params, mask)
        types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
        # test if all params are in bf16 except key
        for name, type_ in types.items():
            if name == key:
                self.assertEqual(type_, jnp.float32, msg=f"param {name} should be in fp32.")
            else:
                self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
||||
|
||||
def test_to_fp16(self):
    """Check ``model.to_fp16`` with and without a mask.

    For each class in ``self.all_model_classes``:

    1. ``to_fp16(model.params)`` must return a tree whose leaves are all
       ``jnp.float16``.
    2. ``to_fp16(model.params, mask)``, where ``mask`` is ``True`` for every
       path except one randomly chosen path, must leave that one leaf in
       ``jnp.float32`` and return every other leaf as ``jnp.float16``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)

        # cast all params to fp16
        params = model.to_fp16(model.params)
        # `jax.tree_util.tree_map` is used instead of the deprecated top-level `jax.tree_map` alias
        types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
        # test if all params are in fp16
        for name, type_ in types.items():
            self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")

        # test masking
        flat_params = flatten_dict(params)
        key = random.choice(list(flat_params.keys()))  # choose a random param
        mask = {path: path != key for path in flat_params}  # don't cast the key
        mask = unflatten_dict(mask)

        params = model.to_fp16(model.params, mask)
        types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
        # test if all params are in fp16 except key
        for name, type_ in types.items():
            if name == key:
                self.assertEqual(type_, jnp.float32, msg=f"param {name} should be in fp32.")
            else:
                self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
|
||||
|
||||
def test_to_fp32(self):
    """Check ``model.to_fp32`` with and without a mask.

    For each class in ``self.all_model_classes`` the params are first cast
    with ``to_fp16`` and then:

    1. ``to_fp32(params)`` must return a tree whose leaves are all
       ``jnp.float32``.
    2. ``to_fp32(params, mask)``, where ``mask`` is ``True`` for every path
       except one randomly chosen path, must leave that one leaf in
       ``jnp.float16`` and return every other leaf as ``jnp.float32``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)

        # cast all params to fp16 and back to fp32
        params = model.to_fp16(model.params)
        params = model.to_fp32(params)

        # test if all params are in fp32
        # `jax.tree_util.tree_map` is used instead of the deprecated top-level `jax.tree_map` alias
        types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
        for name, type_ in types.items():
            self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")

        # test masking
        flat_params = flatten_dict(params)
        key = random.choice(list(flat_params.keys()))  # choose a random param
        mask = {path: path != key for path in flat_params}  # don't cast the key
        mask = unflatten_dict(mask)

        # cast to fp16 and back to fp32 with mask
        params = model.to_fp16(model.params)
        params = model.to_fp32(params, mask)

        # test if all params are in fp32 except key
        types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, params))
        for name, type_ in types.items():
            if name == key:
                self.assertEqual(type_, jnp.float16, msg=f"param {name} should be in fp16.")
            else:
                self.assertEqual(type_, jnp.float32, msg=f"param {name} is not in fp32.")
|
||||
|
||||
def test_save_load_in_fp16(self):
    """Check that fp16 params keep their dtype through ``save_pretrained`` / ``from_pretrained``.

    For each class in ``self.all_model_classes`` the params are cast with
    ``to_fp16``, saved to a temporary directory, reloaded with
    ``from_pretrained``, and every leaf of the reloaded ``model.params`` is
    asserted to be ``jnp.float16``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)

        # convert weights to fp16 and save
        params = model.to_fp16(model.params)
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, params=params)

            # load the weights again and check if they are still in fp16
            model = model_class.from_pretrained(tmpdirname)
            # `jax.tree_util.tree_map` is used instead of the deprecated top-level `jax.tree_map` alias
            types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params))
            for name, type_ in types.items():
                self.assertEqual(type_, jnp.float16, msg=f"param {name} is not in fp16.")
|
||||
|
||||
def test_save_load_in_bf16(self):
    """Check that bf16 params keep their dtype through ``save_pretrained`` / ``from_pretrained``.

    For each class in ``self.all_model_classes`` the params are cast with
    ``to_bf16``, saved to a temporary directory, reloaded with
    ``from_pretrained``, and every leaf of the reloaded ``model.params`` is
    asserted to be ``jnp.bfloat16``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)

        # convert weights to bf16 and save
        params = model.to_bf16(model.params)
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, params=params)

            # load the weights again and check if they are still in bf16
            model = model_class.from_pretrained(tmpdirname)
            # `jax.tree_util.tree_map` is used instead of the deprecated top-level `jax.tree_map` alias
            types = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype, model.params))
            for name, type_ in types.items():
                self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user