[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)
* save intermediate * save intermediate * save intermediate * correct flax bert model file * new module / model naming * make style * almost finish BERT * finish roberta * make fix-copies * delete keys file * last refactor * fixes in run_mlm_flax.py * remove pooled from run_mlm_flax.py` * fix gelu | gelu_new * remove Module from inits * splits * dirty print * preventing warmup_steps == 0 * smaller splits * make fix-copies * dirty print * dirty print * initial_evaluation argument * declaration order fix * proper model initialization/loading * proper initialization * run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug * removed tokenizers warning hack, fixed model re-initialization * reverted training_args.py changes * fix flax from pretrained * improve test in flax * apply sylvains tips * update init * make 0.3.0 compatible * revert tevens changes * revert tevens changes 2 * finalize revert * fix bug * add docs * add pretrained to init * Update src/transformers/modeling_flax_utils.py * fix copies * final improvements Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
committed by
GitHub
parent
51adb97cd6
commit
640e6fe190
@@ -14,14 +14,16 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import BertConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers.models.bert.modeling_flax_bert import FlaxBertModel
|
||||
from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
|
||||
|
||||
|
||||
class FlaxBertModelTester(unittest.TestCase):
|
||||
@@ -105,7 +107,14 @@ class FlaxBertModelTester(unittest.TestCase):
|
||||
@require_flax
|
||||
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxBertModel,) if is_flax_available() else ()
|
||||
all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxBertModelTester(self)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
model = model_class_name.from_pretrained("bert-base-cased")
|
||||
outputs = model(np.ones((1, 1)))
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -26,7 +27,7 @@ if is_flax_available():
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.traverse_util import unflatten_dict
|
||||
from transformers.modeling_flax_utils import convert_state_dict_from_pt
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
@@ -59,21 +60,13 @@ def random_attention_mask(shape, rng=None):
|
||||
return attn_mask
|
||||
|
||||
|
||||
def convert_pt_model_to_flax(pt_model, config, flax_model_cls):
|
||||
state = pt_model.state_dict()
|
||||
state = {k: v.numpy() for k, v in state.items()}
|
||||
state = flax_model_cls.convert_from_pytorch(state, config)
|
||||
state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
|
||||
return flax_model_cls(config, state, dtype=jnp.float32)
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxModelTesterMixin:
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).sum()
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
|
||||
@require_torch
|
||||
@@ -86,30 +79,54 @@ class FlaxModelTesterMixin:
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
pt_model = pt_model_class(config).eval()
|
||||
|
||||
fx_model = convert_pt_model_to_flax(pt_model, config, model_class)
|
||||
fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
fx_model.params = fx_state
|
||||
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||
self.assertEqual(
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
model = model_class(config)
|
||||
|
||||
outputs = model(**inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_loaded = model_loaded(**inputs_dict)
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||
|
||||
@require_torch
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
|
||||
# TODO later: have some way to initialize easily a Flax model from config, for now I go through PT
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
pt_model = pt_model_class(config).eval()
|
||||
|
||||
model = convert_pt_model_to_flax(pt_model, config, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
|
||||
@@ -125,3 +142,14 @@ class FlaxModelTesterMixin:
|
||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
def test_naming_convention(self):
|
||||
for model_class in self.all_model_classes:
|
||||
model_class_name = model_class.__name__
|
||||
module_class_name = (
|
||||
model_class_name[:-5] + "Module" if model_class_name[-5:] == "Model" else model_class_name + "Module"
|
||||
)
|
||||
bert_modeling_flax_module = __import__(model_class.__module__, fromlist=[module_class_name])
|
||||
module_cls = getattr(bert_modeling_flax_module, module_class_name)
|
||||
|
||||
self.assertIsNotNone(module_cls)
|
||||
|
||||
@@ -14,8 +14,10 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import RobertaConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
|
||||
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
@@ -109,3 +111,10 @@ class FlaxRobertaModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxRobertaModelTester(self)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_class_name in self.all_model_classes:
|
||||
model = model_class_name.from_pretrained("roberta-base")
|
||||
outputs = model(np.ones((1, 1)))
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
Reference in New Issue
Block a user