[Flax] Correct pt to flax conversion if from base to head (#13006)
* finish PR * add tests * correct tests * finish * correct other flax tests * better naming * correct naming * finish * apply sylvains suggestions
This commit is contained in:
committed by
GitHub
parent
33929448a1
commit
60e448c87e
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from pickle import UnpicklingError
|
from pickle import UnpicklingError
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -58,60 +59,97 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa
|
|||||||
return flax_state_dict
|
return flax_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def rename_key_and_reshape_tensor(
|
||||||
|
pt_tuple_key: Tuple[str],
|
||||||
|
pt_tensor: np.ndarray,
|
||||||
|
random_flax_state_dict: Dict[str, jnp.ndarray],
|
||||||
|
model_prefix: str,
|
||||||
|
) -> (Tuple[str], np.ndarray):
|
||||||
|
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
|
||||||
|
|
||||||
|
def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
|
||||||
|
"""Checks if ``key`` of ``(prefix,) + key`` is in random_flax_state_dict"""
|
||||||
|
return len(set(random_flax_state_dict) & set([key, (model_prefix,) + key])) > 0
|
||||||
|
|
||||||
|
# layer norm
|
||||||
|
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
||||||
|
if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
|
||||||
|
return renamed_pt_tuple_key, pt_tensor
|
||||||
|
|
||||||
|
# embedding
|
||||||
|
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
||||||
|
if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
|
||||||
|
return renamed_pt_tuple_key, pt_tensor
|
||||||
|
|
||||||
|
# conv layer
|
||||||
|
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
||||||
|
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
|
||||||
|
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
|
||||||
|
return renamed_pt_tuple_key, pt_tensor
|
||||||
|
|
||||||
|
# linear layer
|
||||||
|
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
||||||
|
if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
|
||||||
|
pt_tensor = pt_tensor.T
|
||||||
|
return renamed_pt_tuple_key, pt_tensor
|
||||||
|
|
||||||
|
# old PyTorch layer norm weight
|
||||||
|
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
||||||
|
if pt_tuple_key[-1] == "gamma":
|
||||||
|
return renamed_pt_tuple_key, pt_tensor
|
||||||
|
|
||||||
|
# old PyTorch layer norm bias
|
||||||
|
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
||||||
|
if pt_tuple_key[-1] == "beta":
|
||||||
|
return renamed_pt_tuple_key, pt_tensor
|
||||||
|
|
||||||
|
return pt_tuple_key, pt_tensor
|
||||||
|
|
||||||
|
|
||||||
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
||||||
# convert pytorch tensor to numpy
|
# convert pytorch tensor to numpy
|
||||||
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
|
||||||
|
|
||||||
|
model_prefix = flax_model.base_model_prefix
|
||||||
random_flax_state_dict = flatten_dict(flax_model.params)
|
random_flax_state_dict = flatten_dict(flax_model.params)
|
||||||
flax_state_dict = {}
|
flax_state_dict = {}
|
||||||
|
|
||||||
remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
|
load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
|
||||||
flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
||||||
)
|
)
|
||||||
add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
|
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
|
||||||
flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
|
||||||
)
|
)
|
||||||
|
|
||||||
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
|
# Need to change some parameters name to match Flax names
|
||||||
for pt_key, pt_tensor in pt_state_dict.items():
|
for pt_key, pt_tensor in pt_state_dict.items():
|
||||||
|
|
||||||
pt_tuple_key = tuple(pt_key.split("."))
|
pt_tuple_key = tuple(pt_key.split("."))
|
||||||
|
|
||||||
has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
|
# remove base model prefix if necessary
|
||||||
require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
|
has_base_model_prefix = pt_tuple_key[0] == model_prefix
|
||||||
|
if load_model_with_head_into_base_model and has_base_model_prefix:
|
||||||
if remove_base_model_prefix and has_base_model_prefix:
|
|
||||||
pt_tuple_key = pt_tuple_key[1:]
|
pt_tuple_key = pt_tuple_key[1:]
|
||||||
elif add_base_model_prefix and require_base_model_prefix:
|
|
||||||
pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
|
|
||||||
|
|
||||||
# Correctly rename weight parameters
|
# Correctly rename weight parameters
|
||||||
if pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
|
flax_key, flax_tensor = rename_key_and_reshape_tensor(
|
||||||
pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
|
pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
|
||||||
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
|
)
|
||||||
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
|
|
||||||
elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
|
|
||||||
# conv layer
|
|
||||||
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
|
||||||
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
|
|
||||||
elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
|
|
||||||
# linear layer
|
|
||||||
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
|
|
||||||
pt_tensor = pt_tensor.T
|
|
||||||
elif pt_tuple_key[-1] == "gamma":
|
|
||||||
pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
|
|
||||||
elif pt_tuple_key[-1] == "beta":
|
|
||||||
pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
|
|
||||||
|
|
||||||
if pt_tuple_key in random_flax_state_dict:
|
# add model prefix if necessary
|
||||||
if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
|
require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
|
||||||
|
if load_base_model_into_model_with_head and require_base_model_prefix:
|
||||||
|
flax_key = (model_prefix,) + flax_key
|
||||||
|
|
||||||
|
if flax_key in random_flax_state_dict:
|
||||||
|
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
|
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
|
||||||
f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
|
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# also add unexpected weight so that warning is thrown
|
# also add unexpected weight so that warning is thrown
|
||||||
flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
|
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
|
||||||
|
|
||||||
return unflatten_dict(flax_state_dict)
|
return unflatten_dict(flax_state_dict)
|
||||||
|
|
||||||
@@ -154,10 +192,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
|||||||
flax_state_dict = flatten_dict(flax_state)
|
flax_state_dict = flatten_dict(flax_state)
|
||||||
pt_model_dict = pt_model.state_dict()
|
pt_model_dict = pt_model.state_dict()
|
||||||
|
|
||||||
remove_base_model_prefix = (pt_model.base_model_prefix in flax_state) and (
|
load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
|
||||||
pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()])
|
pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()])
|
||||||
)
|
)
|
||||||
add_base_model_prefix = (pt_model.base_model_prefix not in flax_state) and (
|
load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
|
||||||
pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()])
|
pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()])
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -170,9 +208,9 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
|
|||||||
require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict
|
require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict
|
||||||
|
|
||||||
# adapt flax_key to prepare for loading from/to base model only
|
# adapt flax_key to prepare for loading from/to base model only
|
||||||
if remove_base_model_prefix and has_base_model_prefix:
|
if load_model_with_head_into_base_model and has_base_model_prefix:
|
||||||
flax_key_tuple = flax_key_tuple[1:]
|
flax_key_tuple = flax_key_tuple[1:]
|
||||||
elif add_base_model_prefix and require_base_model_prefix:
|
elif load_base_model_into_model_with_head and require_base_model_prefix:
|
||||||
flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple
|
flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple
|
||||||
|
|
||||||
# rename flax weights to PyTorch format
|
# rename flax weights to PyTorch format
|
||||||
|
|||||||
@@ -213,9 +213,20 @@ class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
def test_save_load_from_base(self):
|
def test_save_load_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# FlaxCLIPVisionModel does not have any base model
|
||||||
def test_save_load_to_base(self):
|
def test_save_load_to_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# FlaxCLIPVisionModel does not have any base model
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_save_load_from_base_pt(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# FlaxCLIPVisionModel does not have any base model
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_save_load_to_base_pt(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_class_name in self.all_model_classes:
|
for model_class_name in self.all_model_classes:
|
||||||
@@ -307,9 +318,20 @@ class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
def test_save_load_from_base(self):
|
def test_save_load_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# FlaxCLIPVisionModel does not have any base model
|
||||||
def test_save_load_to_base(self):
|
def test_save_load_to_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# FlaxCLIPVisionModel does not have any base model
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_save_load_from_base_pt(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# FlaxCLIPVisionModel does not have any base model
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_save_load_to_base_pt(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_class_name in self.all_model_classes:
|
for model_class_name in self.all_model_classes:
|
||||||
|
|||||||
@@ -334,6 +334,63 @@ class FlaxModelTesterMixin:
|
|||||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_save_load_from_base_pt(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class == base_class:
|
||||||
|
continue
|
||||||
|
|
||||||
|
model = base_class(config)
|
||||||
|
base_params = flatten_dict(unfreeze(model.params))
|
||||||
|
|
||||||
|
# convert Flax model to PyTorch model
|
||||||
|
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||||
|
pt_model = pt_model_class(config).eval()
|
||||||
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||||
|
|
||||||
|
# check that all base model weights are loaded correctly
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
# save pt model
|
||||||
|
pt_model.save_pretrained(tmpdirname)
|
||||||
|
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
|
||||||
|
|
||||||
|
for key in base_param_from_head.keys():
|
||||||
|
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_save_load_to_base_pt(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class == base_class:
|
||||||
|
continue
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
|
||||||
|
|
||||||
|
# convert Flax model to PyTorch model
|
||||||
|
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||||
|
pt_model = pt_model_class(config).eval()
|
||||||
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||||
|
|
||||||
|
# check that all base model weights are loaded correctly
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_model.save_pretrained(tmpdirname)
|
||||||
|
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
base_params = flatten_dict(unfreeze(base_model.params))
|
||||||
|
|
||||||
|
for key in base_params_from_head.keys():
|
||||||
|
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_jit_compilation(self):
|
def test_jit_compilation(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -17,8 +17,15 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import transformers
|
||||||
from transformers import is_flax_available
|
from transformers import is_flax_available
|
||||||
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
|
from transformers.testing_utils import (
|
||||||
|
is_pt_flax_cross_test,
|
||||||
|
require_flax,
|
||||||
|
require_sentencepiece,
|
||||||
|
require_tokenizers,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_generation_flax_utils import FlaxGenerationTesterMixin
|
from .test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||||
@@ -40,6 +47,7 @@ if is_flax_available():
|
|||||||
from flax.training.common_utils import onehot
|
from flax.training.common_utils import onehot
|
||||||
from flax.traverse_util import flatten_dict
|
from flax.traverse_util import flatten_dict
|
||||||
from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
|
from transformers import FLAX_MODEL_MAPPING, ByT5Tokenizer, T5Config, T5Tokenizer
|
||||||
|
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||||
from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right
|
from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, shift_tokens_right
|
||||||
|
|
||||||
|
|
||||||
@@ -363,6 +371,65 @@ class FlaxT5ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.
|
|||||||
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
# overwrite since special base model prefix is used
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_save_load_from_base_pt(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class == base_class:
|
||||||
|
continue
|
||||||
|
|
||||||
|
model = base_class(config)
|
||||||
|
base_params = flatten_dict(unfreeze(model.params))
|
||||||
|
|
||||||
|
# convert Flax model to PyTorch model
|
||||||
|
pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||||
|
pt_model = pt_model_class(config).eval()
|
||||||
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||||
|
|
||||||
|
# check that all base model weights are loaded correctly
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
# save pt model
|
||||||
|
pt_model.save_pretrained(tmpdirname)
|
||||||
|
head_model = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
base_param_from_head = flatten_dict(unfreeze(head_model.params))
|
||||||
|
|
||||||
|
for key in base_param_from_head.keys():
|
||||||
|
max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
# overwrite since special base model prefix is used
|
||||||
|
@is_pt_flax_cross_test
|
||||||
|
def test_save_load_to_base_pt(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
base_class = FLAX_MODEL_MAPPING[config.__class__]
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if model_class == base_class:
|
||||||
|
continue
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
base_params_from_head = flatten_dict(unfreeze(model.params))
|
||||||
|
|
||||||
|
# convert Flax model to PyTorch model
|
||||||
|
pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning
|
||||||
|
pt_model = pt_model_class(config).eval()
|
||||||
|
pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params)
|
||||||
|
|
||||||
|
# check that all base model weights are loaded correctly
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
pt_model.save_pretrained(tmpdirname)
|
||||||
|
base_model = base_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
|
base_params = flatten_dict(unfreeze(base_model.params))
|
||||||
|
|
||||||
|
for key in base_params_from_head.keys():
|
||||||
|
max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
|
|||||||
Reference in New Issue
Block a user