[WIP][Flax] Add general conversion script (#10809)
* save intermediate
* finish first version
* delete some more
* improve import
* fix roberta
* Update src/transformers/modeling_flax_pytorch_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_flax_pytorch_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* small corrections
* apply all comments
* fix deterministic
* make fix-copies

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
604c085087
commit
8780caa388
@@ -27,7 +27,7 @@ if is_flax_available():
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from transformers.modeling_flax_utils import convert_state_dict_from_pt
|
||||
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
@@ -79,8 +79,8 @@ class FlaxModelTesterMixin:
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
pt_model = pt_model_class(config).eval()
|
||||
|
||||
fx_state = convert_state_dict_from_pt(model_class, pt_model.state_dict(), config)
|
||||
fx_model = model_class(config, dtype=jnp.float32)
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
|
||||
|
||||
Reference in New Issue
Block a user