[Flax Generation] Correct inconsistencies PyTorch/Flax (#12662)

* fix_torch_device_generate_test

* remove @

* correct greedy search

* save intertmed

* add final logits bias

* correct

* up

* add more tests

* fix another bug

* finish tests

* finish marian tests

* up

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-07-13 18:53:30 +01:00
committed by GitHub
parent 7a22a02a70
commit cee2d2135f
7 changed files with 113 additions and 44 deletions

View File

@@ -16,8 +16,9 @@ import random
import numpy as np
from transformers import is_flax_available
from transformers.testing_utils import require_flax
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
if is_flax_available():
@@ -26,10 +27,15 @@ if is_flax_available():
import jax
import jax.numpy as jnp
from jax import jit
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
if is_torch_available():
import torch
def ids_tensor(shape, vocab_size, rng=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
@@ -78,6 +84,29 @@ class FlaxGenerationTesterMixin:
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, max_length
@is_pt_flax_cross_test
def test_greedy_generate_pt_fx(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = False
config.max_length = max_length
config.decoder_start_token_id = 0
for model_class in self.all_generative_model_classes:
flax_model = model_class(config)
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)
flax_generation_outputs = flax_model.generate(input_ids).sequences
pt_generation_outputs = pt_model.generate(torch.tensor(input_ids, dtype=torch.long))
if flax_generation_outputs.shape[-1] > pt_generation_outputs.shape[-1]:
flax_generation_outputs = flax_generation_outputs[:, : pt_generation_outputs.shape[-1]]
self.assertListEqual(pt_generation_outputs.numpy().tolist(), flax_generation_outputs.tolist())
def test_greedy_generate(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = False