[WIP][Flax] Add general conversion script (#10809)

* save intermediate

* finish first version

* delete some more

* improve import

* fix roberta

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_flax_pytorch_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* small corrections

* apply all comments

* fix deterministic

* make fix-copies

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2021-03-30 12:13:59 +03:00
committed by GitHub
parent 604c085087
commit 8780caa388
7 changed files with 370 additions and 297 deletions

View File

@@ -27,7 +27,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.modeling_flax_utils import convert_state_dict_from_pt
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
@@ -79,8 +79,8 @@ class FlaxModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}