[Flax Generation] Correct inconsistencies PyTorch/Flax (#12662)
* fix_torch_device_generate_test * remove @ * correct greedy search * save intertmed * add final logits bias * correct * up * add more tests * fix another bug * finish tests * finish marian tests * up Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
7a22a02a70
commit
cee2d2135f
@@ -16,8 +16,9 @@ import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_flax_available
|
||||
from transformers.testing_utils import require_flax
|
||||
import transformers
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -26,10 +27,15 @@ if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def ids_tensor(shape, vocab_size, rng=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
if rng is None:
|
||||
@@ -78,6 +84,29 @@ class FlaxGenerationTesterMixin:
|
||||
config.pad_token_id = config.eos_token_id
|
||||
return config, input_ids, attention_mask, max_length
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_greedy_generate_pt_fx(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = False
|
||||
config.max_length = max_length
|
||||
config.decoder_start_token_id = 0
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
flax_model = model_class(config)
|
||||
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)
|
||||
|
||||
flax_generation_outputs = flax_model.generate(input_ids).sequences
|
||||
pt_generation_outputs = pt_model.generate(torch.tensor(input_ids, dtype=torch.long))
|
||||
|
||||
if flax_generation_outputs.shape[-1] > pt_generation_outputs.shape[-1]:
|
||||
flax_generation_outputs = flax_generation_outputs[:, : pt_generation_outputs.shape[-1]]
|
||||
|
||||
self.assertListEqual(pt_generation_outputs.numpy().tolist(), flax_generation_outputs.tolist())
|
||||
|
||||
def test_greedy_generate(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = False
|
||||
|
||||
Reference in New Issue
Block a user