[Flax] [WIP] allow loading head model with base model weights (#12255)
* boom boom * remove flax clip example * allow loading head model with base model weights * add test * fix imports * disable save, load test for clip * add test_save_load_to_base
This commit is contained in:
@@ -348,6 +348,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
if cls.base_model_prefix not in dict(model.params) and cls.base_model_prefix in state:
|
||||
state = state[cls.base_model_prefix]
|
||||
|
||||
# if model is head model and we are loading weights from base model
|
||||
# we initialize new params dict with base_model_prefix
|
||||
if cls.base_model_prefix in dict(model.params) and cls.base_model_prefix not in state:
|
||||
state = {cls.base_model_prefix: state}
|
||||
|
||||
# flatten dicts
|
||||
state = flatten_dict(state)
|
||||
|
||||
|
||||
@@ -209,6 +209,13 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
[self.model_tester.num_attention_heads, seq_length, seq_length],
|
||||
)
|
||||
|
||||
# FlaxCLIPVisionModel does not have any base model
|
||||
def test_save_load_from_base(self):
|
||||
pass
|
||||
|
||||
def test_save_load_to_base(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
@@ -296,6 +303,13 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxCLIPTextModelTester(self)
|
||||
|
||||
# FlaxCLIPTextModel does not have any base model
|
||||
def test_save_load_from_base(self):
|
||||
pass
|
||||
|
||||
def test_save_load_to_base(self):
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
|
||||
@@ -32,7 +32,9 @@ if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
from flax.traverse_util import flatten_dict
|
||||
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
@@ -273,6 +275,50 @@ class FlaxModelTesterMixin:
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||
|
||||
def test_save_load_from_base(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = base_class(config)
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
head_model = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
||||
|
||||
for key in base_param_from_head.keys():
|
||||
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_save_load_to_base(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||
|
||||
# check that all base model weights are loaded correctly
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
base_model = base_class.from_pretrained(tmpdirname)
|
||||
|
||||
base_params = flatten_dict(unfreeze(base_model.params))
|
||||
|
||||
for key in base_params_from_head.keys():
|
||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
@slow
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user