[Flax] Align FlaxBertForMaskedLM with BertForMaskedLM, implement from_pretrained, init (#9054)

* save intermediate

* save intermediate

* save intermediate

* correct flax bert model file

* new module / model naming

* make style

* almost finish BERT

* finish roberta

* make fix-copies

* delete keys file

* last refactor

* fixes in run_mlm_flax.py

* remove pooled from run_mlm_flax.py`

* fix gelu | gelu_new

* remove Module from inits

* splits

* dirty print

* preventing warmup_steps == 0

* smaller splits

* make fix-copies

* dirty print

* dirty print

* initial_evaluation argument

* declaration order fix

* proper model initialization/loading

* proper initialization

* run_mlm_flax improvements: improper model inputs bugfix + automatic dataset splitting + tokenizers parallelism warning + avoiding warmup_steps=0 bug

* removed tokenizers warning hack, fixed model re-initialization

* reverted training_args.py changes

* fix flax from pretrained

* improve test in flax

* apply sylvains tips

* update init

* make 0.3.0 compatible

* revert tevens changes

* revert tevens changes 2

* finalize revert

* fix bug

* add docs

* add pretrained to init

* Update src/transformers/modeling_flax_utils.py

* fix copies

* final improvements

Co-authored-by: TevenLeScao <teven.lescao@gmail.com>
This commit is contained in:
Patrick von Platen
2020-12-16 13:03:32 +01:00
committed by GitHub
parent 51adb97cd6
commit 640e6fe190
14 changed files with 700 additions and 359 deletions

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import random
import tempfile
import numpy as np
@@ -26,7 +27,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from flax.traverse_util import unflatten_dict
from transformers.modeling_flax_utils import convert_state_dict_from_pt
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
@@ -59,21 +60,13 @@ def random_attention_mask(shape, rng=None):
return attn_mask
def convert_pt_model_to_flax(pt_model, config, flax_model_cls):
state = pt_model.state_dict()
state = {k: v.numpy() for k, v in state.items()}
state = flax_model_cls.convert_from_pytorch(state, config)
state = unflatten_dict({tuple(k.split(".")): v for k, v in state.items()})
return flax_model_cls(config, state, dtype=jnp.float32)
@require_flax
class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).sum()
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@require_torch
@@ -86,30 +79,54 @@ class FlaxModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = convert_pt_model_to_flax(pt_model, config, model_class)
fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
fx_model = model_class(config, dtype=jnp.float32)
fx_model.params = fx_state
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**inputs_dict)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
model = model_class(config)
outputs = model(**inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_loaded = model_class.from_pretrained(tmpdirname)
outputs_loaded = model_loaded(**inputs_dict)
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3)
@require_torch
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# TODO later: have some way to initialize easily a Flax model from config, for now I go through PT
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
model = convert_pt_model_to_flax(pt_model, config, model_class)
model = model_class(config)
@jax.jit
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
@@ -125,3 +142,14 @@ class FlaxModelTesterMixin:
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def test_naming_convention(self):
for model_class in self.all_model_classes:
model_class_name = model_class.__name__
module_class_name = (
model_class_name[:-5] + "Module" if model_class_name[-5:] == "Model" else model_class_name + "Module"
)
bert_modeling_flax_module = __import__(model_class.__module__, fromlist=[module_class_name])
module_cls = getattr(bert_modeling_flax_module, module_class_name)
self.assertIsNotNone(module_cls)