[Flax] Correct pt to flax conversion if from base to head (#13006)

* finish PR

* add tests

* correct tests

* finish

* correct other flax tests

* better naming

* correct naming

* finish

* apply sylvains suggestions
This commit is contained in:
Patrick von Platen
2021-08-05 18:38:50 +02:00
committed by GitHub
parent 33929448a1
commit 60e448c87e
4 changed files with 220 additions and 36 deletions

View File

@@ -213,9 +213,20 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_save_load_from_base(self):
pass
# FlaxCLIPVisionModel does not have any base model
def test_save_load_to_base(self):
pass
# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
    # Deliberate no-op replacing FlaxModelTesterMixin.test_save_load_from_base_pt:
    # FlaxCLIPVisionModel has no base model, so the PyTorch -> Flax base-to-head
    # round trip does not apply.
    return None
# FlaxCLIPVisionModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
    # Deliberate no-op replacing FlaxModelTesterMixin.test_save_load_to_base_pt:
    # FlaxCLIPVisionModel has no base model, so the PyTorch -> Flax head-to-base
    # round trip does not apply.
    return None
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
@@ -307,9 +318,20 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
def test_save_load_from_base(self):
pass
# FlaxCLIPTextModel does not have any base model
def test_save_load_to_base(self):
pass
# FlaxCLIPTextModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
    # Deliberate no-op replacing FlaxModelTesterMixin.test_save_load_from_base_pt
    # for FlaxCLIPTextModelTest: nothing is executed, so the test always passes.
    return None
# FlaxCLIPTextModel does not have any base model
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
    # Deliberate no-op replacing FlaxModelTesterMixin.test_save_load_to_base_pt
    # for FlaxCLIPTextModelTest: nothing is executed, so the test always passes.
    return None
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:

View File

@@ -334,6 +334,63 @@ class FlaxModelTesterMixin:
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
    # Round trip under test, for every non-base class in ``self.all_model_classes``:
    #   Flax base model -> PyTorch base model -> saved checkpoint
    #   -> Flax head model loaded with ``from_pt=True``.
    # Every weight found under the head model's ``base_model_prefix`` must match the
    # corresponding weight of the original Flax base model.
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            # the base class itself has no head to load into
            continue

        model = base_class(config)
        base_params = flatten_dict(unfreeze(model.params))

        # convert Flax model to PyTorch model
        pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            # save pt model
            pt_model.save_pretrained(tmpdirname)
            head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

            base_params_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))

            for key in base_params_from_head:
                # Largest absolute element-wise deviation. A signed sum would let
                # positive and negative errors cancel out, and any negative total
                # would trivially satisfy the ``<=`` check below.
                max_diff = abs(base_params[key] - base_params_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
    # Round trip under test, for every non-base class in ``self.all_model_classes``:
    #   Flax head model -> PyTorch head model -> saved checkpoint
    #   -> Flax base model loaded with ``from_pt=True``.
    # Every weight stored under the head model's ``base_model_prefix`` must match the
    # corresponding weight of the reloaded base model.
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            # the base class itself has no head to strip
            continue

        model = model_class(config)
        base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))

        # convert Flax model to PyTorch model
        pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

            base_params = flatten_dict(unfreeze(base_model.params))

            for key in base_params_from_head:
                # Largest absolute element-wise deviation. A signed sum would let
                # positive and negative errors cancel out, and any negative total
                # would trivially satisfy the ``<=`` check below.
                max_diff = abs(base_params[key] - base_params_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@slow
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

View File

@@ -17,8 +17,15 @@ import unittest
import numpy as np
import transformers
from transformers import is_flax_available
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
from transformers.testing_utils import (
is_pt_flax_cross_test,
require_flax,
require_sentencepiece,
require_tokenizers,
slow,
)
from .test_configuration_common import ConfigTester
from .test_generation_flax_utils import FlaxGenerationTesterMixin
@@ -40,6 +47,7 @@ if is_flax_available():
from flax.training.common_utils import onehot
from flax.traverse_util import flatten_dict
from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right
@@ -363,6 +371,65 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_from_base_pt(self):
    # Round trip under test, for every non-base class in ``self.all_model_classes``:
    #   Flax base model -> PyTorch base model -> saved checkpoint
    #   -> Flax head model loaded with ``from_pt=True``.
    # Unlike the generic mixin test, the head model's parameters are compared from
    # the top level of ``head_model.params`` instead of from under ``base_model_prefix``.
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            # the base class itself has no head to load into
            continue

        model = base_class(config)
        base_params = flatten_dict(unfreeze(model.params))

        # convert Flax model to PyTorch model
        pt_model_class = getattr(transformers, base_class.__name__[4:])  # Skip the "Flax" at the beginning
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            # save pt model
            pt_model.save_pretrained(tmpdirname)
            head_model = model_class.from_pretrained(tmpdirname, from_pt=True)

            base_params_from_head = flatten_dict(unfreeze(head_model.params))

            for key in base_params_from_head:
                # Largest absolute element-wise deviation. A signed sum would let
                # positive and negative errors cancel out, and any negative total
                # would trivially satisfy the ``<=`` check below.
                max_diff = abs(base_params[key] - base_params_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
# overwrite since special base model prefix is used
@is_pt_flax_cross_test
def test_save_load_to_base_pt(self):
    # Round trip under test, for every non-base class in ``self.all_model_classes``:
    #   Flax head model -> PyTorch head model -> saved checkpoint
    #   -> Flax base model loaded with ``from_pt=True``.
    # Unlike the generic mixin test, the head model's parameters are taken from the
    # top level of ``model.params`` instead of from under ``base_model_prefix``.
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    base_class = FLAX_MODEL_MAPPING[config.__class__]

    for model_class in self.all_model_classes:
        if model_class == base_class:
            # the base class itself has no head to strip
            continue

        model = model_class(config)
        base_params_from_head = flatten_dict(unfreeze(model.params))

        # convert Flax model to PyTorch model
        pt_model_class = getattr(transformers, model_class.__name__[4:])  # Skip the "Flax" at the beginning
        pt_model = pt_model_class(config).eval()
        pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            pt_model.save_pretrained(tmpdirname)
            base_model = base_class.from_pretrained(tmpdirname, from_pt=True)

            base_params = flatten_dict(unfreeze(base_model.params))

            for key in base_params_from_head:
                # Largest absolute element-wise deviation. A signed sum would let
                # positive and negative errors cancel out, and any negative total
                # would trivially satisfy the ``<=`` check below.
                max_diff = abs(base_params[key] - base_params_from_head[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
@require_sentencepiece
@require_tokenizers