[Flax] [WIP] allow loading head model with base model weights (#12255)
* boom boom * remove flax clip example * allow loading head model with base model weights * add test * fix imports * disable save, load test for clip * add test_save_load_to_base
This commit is contained in:
@@ -348,6 +348,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
|||||||
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
||||||
state = state[cls.base_model_prefix]
|
state = state[cls.base_model_prefix]
|
||||||
|
|
||||||
|
# if model is head model and we are loading weights from base model
|
||||||
|
# we initialize new params dict with base_model_prefix
|
||||||
|
if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
|
||||||
|
state = {cls.base_model_prefix: state}
|
||||||
|
|
||||||
# flatten dicts
|
# flatten dicts
|
||||||
state = flatten_dict(state)
|
state = flatten_dict(state)
|
||||||
|
|
||||||
|
|||||||
@@ -209,6 +209,13 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
[self.model_tester.num_attention_heads, seq_length, seq_length],
|
[self.model_tester.num_attention_heads, seq_length, seq_length],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FlaxCLIPVisionModel does not have any base model
|
||||||
|
def test_save_load_from_base(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_save_load_to_base(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_class_name in self.all_model_classes:
|
for model_class_name in self.all_model_classes:
|
||||||
@@ -296,6 +303,13 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = FlaxCLIPTextModelTester(self)
|
self.model_tester = FlaxCLIPTextModelTester(self)
|
||||||
|
|
||||||
|
# FlaxCLIPTextModel does not have any base model
|
||||||
|
def test_save_load_from_base(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_save_load_to_base(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_class_name in self.all_model_classes:
|
for model_class_name in self.all_model_classes:
|
||||||
|
|||||||
@@ -32,7 +32,9 @@ if is_flax_available():
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jaxlib.xla_extension as jax_xla
|
import jaxlib.xla_extension as jax_xla
|
||||||
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
from flax.core.frozen_dict import unfreeze
|
||||||
|
from flax.traverse_util import flatten_dict
|
||||||
|
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING
|
||||||
from transformers.modeling_flax_pytorch_utils import (
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
convert_pytorch_state_dict_to_flax,
|
convert_pytorch_state_dict_to_flax,
|
||||||
load_flax_weights_in_pytorch_model,
|
load_flax_weights_in_pytorch_model,
|
||||||
@@ -273,6 +275,50 @@ class FlaxModelTesterMixin:
|
|||||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||||
self.assert_almost_equals(output_loaded, output, 1e-3)
|
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||||
|
|
||||||
|
def test_save_load_from_base(self):
    """Check that base-model weights survive being loaded into each head model.

    A freshly initialised base model is saved with ``save_pretrained`` and
    reloaded through every non-base class in ``self.all_model_classes``.
    Each parameter found under the head model's ``base_model_prefix`` must
    match the corresponding base parameter to within ``1e-3``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        # The base class has no prefix sub-tree to compare against.
        if model_class == base_class:
            continue

        model = base_class(config)
        base_params = flatten_dict(unfreeze(model.params))

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            head_model = model_class.from_pretrained(tmpdirname)

            base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

            for key in base_param_from_head:
                # Use the largest absolute element-wise deviation. A signed
                # sum lets positive and negative errors cancel, and any
                # negative total would trivially satisfy the bound.
                max_diff = abs(base_params[key] - base_param_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
def test_save_load_to_base(self):
    """Check that a head model's base weights survive being loaded into the base model.

    Every non-base class in ``self.all_model_classes`` is initialised, saved
    with ``save_pretrained`` and reloaded through the base class. Each
    parameter stored under the head model's ``base_model_prefix`` must match
    the corresponding parameter of the reloaded base model to within ``1e-3``.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        # Saving and reloading the base class itself is not a head-to-base conversion.
        if model_class == base_class:
            continue

        model = model_class(config)
        base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            base_model = base_class.from_pretrained(tmpdirname)

            base_params = flatten_dict(unfreeze(base_model.params))

            for key in base_params_from_head:
                # Use the largest absolute element-wise deviation. A signed
                # sum lets positive and negative errors cancel, and any
                # negative total would trivially satisfy the bound.
                max_diff = abs(base_params[key] - base_params_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_jit_compilation(self):
|
def test_jit_compilation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user