[tests] remove tf/flax tests in /generation (#36235)
This commit is contained in:
@@ -1,343 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a clone of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_flax_available
|
||||
from transformers.testing_utils import require_flax
|
||||
|
||||
from ..test_modeling_flax_common import ids_tensor
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from transformers.generation import (
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxNoRepeatNGramLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
)
|
||||
|
||||
|
||||
@require_flax
|
||||
class LogitsProcessorTest(unittest.TestCase):
|
||||
def _get_uniform_logits(self, batch_size: int, length: int):
|
||||
scores = jnp.ones((batch_size, length)) / length
|
||||
return scores
|
||||
|
||||
def test_temperature_dist_warper(self):
|
||||
input_ids = None
|
||||
length = 20
|
||||
|
||||
scores = self._get_uniform_logits(batch_size=2, length=length)
|
||||
|
||||
# tweak scores to not be uniform anymore
|
||||
scores = scores.at[1, 5].set((1 / length) + 0.1) # peak, 1st batch
|
||||
scores = scores.at[1, 10].set((1 / length) - 0.4) # valley, 1st batch
|
||||
|
||||
# compute softmax
|
||||
probs = jax.nn.softmax(scores, axis=-1)
|
||||
|
||||
temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||
temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3)
|
||||
|
||||
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy(), cur_len=None), axis=-1)
|
||||
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy(), cur_len=None), axis=-1)
|
||||
|
||||
# uniform distribution stays uniform
|
||||
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
|
||||
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3))
|
||||
|
||||
# sharp peaks get higher, valleys get lower
|
||||
self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max())
|
||||
self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min())
|
||||
|
||||
# smooth peaks get lower, valleys get higher
|
||||
self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max())
|
||||
self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min())
|
||||
|
||||
def test_top_k_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create ramp distribution
|
||||
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
|
||||
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
|
||||
|
||||
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||
|
||||
scores = top_k_warp(input_ids, ramp_logits, cur_len=None)
|
||||
|
||||
# check that correct tokens are filtered
|
||||
self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
|
||||
self.assertListEqual(jnp.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])
|
||||
|
||||
# check special case
|
||||
length = 5
|
||||
top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
|
||||
|
||||
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
|
||||
scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len=None)
|
||||
|
||||
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
||||
self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
|
||||
|
||||
def test_top_p_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
|
||||
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= top_p
|
||||
# exp (-inf) => 0
|
||||
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
|
||||
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||
|
||||
# check edge cases with negative and extreme logits
|
||||
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy() - (
|
||||
vocab_size // 2
|
||||
)
|
||||
|
||||
# make ramp_logits more extreme
|
||||
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||
|
||||
# make sure at least 2 tokens are kept
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
||||
filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len=None)
|
||||
|
||||
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||
self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])
|
||||
|
||||
def test_min_length_dist_processor(self):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
|
||||
min_dist_processor = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
|
||||
# check that min length is applied at length 5
|
||||
input_ids = ids_tensor((batch_size, 20), vocab_size=20)
|
||||
cur_len = 5
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len)
|
||||
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")])
|
||||
|
||||
# check that min length is not applied anymore at length 15
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
cur_len = 15
|
||||
scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len)
|
||||
self.assertFalse(jnp.isinf(scores_before_min_length).any())
|
||||
|
||||
def test_forced_bos_token_logits_processor(self):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
bos_token_id = 0
|
||||
|
||||
logits_processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
||||
|
||||
# check that all scores are -inf except the bos_token_id score
|
||||
input_ids = ids_tensor((batch_size, 1), vocab_size=20)
|
||||
cur_len = 1
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len=cur_len)
|
||||
self.assertTrue(jnp.isneginf(scores[:, bos_token_id + 1 :]).all())
|
||||
self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0]) # score for bos_token_id shold be zero
|
||||
|
||||
# check that bos_token_id is not forced if current length is greater than 1
|
||||
cur_len = 3
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len=cur_len)
|
||||
self.assertFalse(jnp.isinf(scores).any())
|
||||
|
||||
def test_forced_eos_token_logits_processor(self):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
max_length = 5
|
||||
|
||||
logits_processor = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||
|
||||
# check that all scores are -inf except the eos_token_id when max_length is reached
|
||||
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
|
||||
cur_len = 4
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len=cur_len)
|
||||
self.assertTrue(jnp.isneginf(scores[:, eos_token_id + 1 :]).all())
|
||||
self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero
|
||||
|
||||
# check that eos_token_id is not forced if max_length is not reached
|
||||
cur_len = 3
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len=cur_len)
|
||||
self.assertFalse(jnp.isinf(scores).any())
|
||||
|
||||
def test_no_repeat_ngram_dist_processor(self):
|
||||
vocab_size = 3
|
||||
batch_size = 2
|
||||
|
||||
cur_len = 4
|
||||
input_ids = np.array([[1, 1, 2, 1], [0, 1, 0, 1]], dtype="i4")
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
no_repeat_proc_2_gram = FlaxNoRepeatNGramLogitsProcessor(2)
|
||||
no_repeat_proc_3_gram = FlaxNoRepeatNGramLogitsProcessor(3)
|
||||
|
||||
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores, cur_len=cur_len)
|
||||
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores, cur_len=cur_len)
|
||||
|
||||
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
|
||||
self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])
|
||||
|
||||
# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
|
||||
self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]])
|
||||
|
||||
def test_processor_list(self):
|
||||
batch_size = 4
|
||||
sequence_length = 10
|
||||
vocab_size = 15
|
||||
eos_token_id = 2
|
||||
bos_token_id = 1
|
||||
max_length = 15
|
||||
|
||||
# dummy input_ids and scores
|
||||
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
|
||||
input_ids_comp = input_ids.copy()
|
||||
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_comp = scores.copy()
|
||||
|
||||
# instantiate all dist processors
|
||||
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||
no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)
|
||||
|
||||
# instantiate all logits processors
|
||||
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
||||
eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||
|
||||
cur_len = 10
|
||||
|
||||
# no processor list
|
||||
scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
|
||||
scores = top_k_warp(input_ids, scores, cur_len=cur_len)
|
||||
scores = top_p_warp(input_ids, scores, cur_len=cur_len)
|
||||
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)
|
||||
|
||||
# with processor list
|
||||
processor = FlaxLogitsProcessorList(
|
||||
[
|
||||
temp_dist_warp,
|
||||
top_k_warp,
|
||||
top_p_warp,
|
||||
min_dist_proc,
|
||||
bos_dist_proc,
|
||||
eos_dist_proc,
|
||||
no_repeat_proc,
|
||||
]
|
||||
)
|
||||
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
|
||||
|
||||
# scores should be equal
|
||||
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
|
||||
|
||||
# input_ids should never be changed
|
||||
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
|
||||
|
||||
def test_processor_list_jitted(self):
|
||||
batch_size = 4
|
||||
sequence_length = 10
|
||||
vocab_size = 15
|
||||
eos_token_id = 2
|
||||
bos_token_id = 1
|
||||
max_length = 15
|
||||
|
||||
# dummy input_ids and scores
|
||||
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
|
||||
input_ids_comp = input_ids.copy()
|
||||
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_comp = scores.copy()
|
||||
|
||||
# instantiate all dist processors
|
||||
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||
no_repeat_proc = FlaxNoRepeatNGramLogitsProcessor(2)
|
||||
|
||||
# instantiate all logits processors
|
||||
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
||||
eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||
|
||||
cur_len = 10
|
||||
|
||||
# no processor list
|
||||
def run_no_processor_list(input_ids, scores, cur_len):
|
||||
scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
|
||||
scores = top_k_warp(input_ids, scores, cur_len=cur_len)
|
||||
scores = top_p_warp(input_ids, scores, cur_len=cur_len)
|
||||
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||
scores = no_repeat_proc(input_ids, scores, cur_len=cur_len)
|
||||
return scores
|
||||
|
||||
# with processor list
|
||||
def run_processor_list(input_ids, scores, cur_len):
|
||||
processor = FlaxLogitsProcessorList(
|
||||
[
|
||||
temp_dist_warp,
|
||||
top_k_warp,
|
||||
top_p_warp,
|
||||
min_dist_proc,
|
||||
bos_dist_proc,
|
||||
eos_dist_proc,
|
||||
no_repeat_proc,
|
||||
]
|
||||
)
|
||||
scores = processor(input_ids, scores, cur_len=cur_len)
|
||||
return scores
|
||||
|
||||
jitted_run_no_processor_list = jax.jit(run_no_processor_list)
|
||||
jitted_run_processor_list = jax.jit(run_processor_list)
|
||||
|
||||
scores = jitted_run_no_processor_list(input_ids, scores, cur_len)
|
||||
scores_comp = jitted_run_processor_list(input_ids, scores_comp, cur_len)
|
||||
|
||||
# scores should be equal
|
||||
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
|
||||
|
||||
# input_ids should never be changed
|
||||
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
|
||||
@@ -1,313 +0,0 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import os
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
|
||||
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
def ids_tensor(shape, vocab_size, rng=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
if rng is None:
|
||||
rng = random.Random()
|
||||
|
||||
total_dims = 1
|
||||
for dim in shape:
|
||||
total_dims *= dim
|
||||
|
||||
values = []
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.randint(0, vocab_size - 1))
|
||||
|
||||
output = np.array(values, dtype=jnp.int32).reshape(shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def random_attention_mask(shape, rng=None):
|
||||
attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
|
||||
# make sure that at least one token is attended to for each batch
|
||||
attn_mask[:, -1] = 1
|
||||
return attn_mask
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGenerationTesterMixin:
|
||||
model_tester = None
|
||||
|
||||
def _get_input_ids_and_config(self):
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# cut to half length & take max batch_size 3
|
||||
max_batch_size = 2
|
||||
sequence_length = inputs["input_ids"].shape[-1] // 2
|
||||
input_ids = inputs["input_ids"][:max_batch_size, :sequence_length]
|
||||
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
attention_mask = attention_mask[:max_batch_size, :sequence_length]
|
||||
|
||||
# generate max 5 tokens
|
||||
max_length = input_ids.shape[-1] + 5
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
return config, input_ids, attention_mask, max_length
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_greedy_generate_pt_fx(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = False
|
||||
config.max_length = max_length
|
||||
config.decoder_start_token_id = 0
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
flax_model = model_class(config)
|
||||
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)
|
||||
|
||||
# Generate max 5 tokens only otherwise seems to be numerical error accumulation
|
||||
pt_model.generation_config.max_length = 5
|
||||
flax_model.generation_config.max_length = 5
|
||||
|
||||
flax_generation_outputs = flax_model.generate(input_ids).sequences
|
||||
pt_generation_outputs = pt_model.generate(torch.tensor(input_ids, dtype=torch.long))
|
||||
|
||||
if flax_generation_outputs.shape[-1] > pt_generation_outputs.shape[-1]:
|
||||
flax_generation_outputs = flax_generation_outputs[:, : pt_generation_outputs.shape[-1]]
|
||||
|
||||
self.assertListEqual(pt_generation_outputs.numpy().tolist(), flax_generation_outputs.tolist())
|
||||
|
||||
def test_greedy_generate(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = False
|
||||
config.max_length = max_length
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_sample_generate(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = True
|
||||
config.max_length = max_length
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_beam_search_generate(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = False
|
||||
config.max_length = max_length
|
||||
config.num_beams = 2
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_beam_search_generate_num_return_sequences(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = False
|
||||
config.max_length = max_length
|
||||
config.num_beams = 2
|
||||
config.num_return_sequences = 2
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[0], input_ids.shape[0] * config.num_return_sequences)
|
||||
|
||||
def test_sample_generate_logits_warper(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = True
|
||||
config.max_length = max_length
|
||||
config.temperature = 0.8
|
||||
config.top_k = 10
|
||||
config.top_p = 0.3
|
||||
config.min_length = 1
|
||||
config.forced_bos_token_id = 8
|
||||
config.forced_eos_token_id = 9
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_greedy_generate_logits_warper(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.max_length = max_length
|
||||
config.min_length = 1
|
||||
config.forced_bos_token_id = 8
|
||||
config.forced_eos_token_id = 9
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_beam_search_generate_logits_warper(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.max_length = max_length
|
||||
config.num_beams = 2
|
||||
config.min_length = 1
|
||||
config.forced_bos_token_id = 8
|
||||
config.forced_eos_token_id = 9
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_greedy_generate_attn_mask(self):
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# pad attention mask on the left
|
||||
attention_mask = attention_mask.at[(0, 0)].set(0)
|
||||
|
||||
config.do_sample = False
|
||||
config.max_length = max_length
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_sample_generate_attn_mask(self):
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# pad attention mask on the left
|
||||
attention_mask = attention_mask.at[(0, 0)].set(0)
|
||||
|
||||
config.do_sample = True
|
||||
config.max_length = max_length
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_beam_search_generate_attn_mask(self):
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# pad attention mask on the left
|
||||
attention_mask = attention_mask.at[(0, 0)].set(0)
|
||||
|
||||
config.num_beams = 2
|
||||
config.max_length = max_length
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGenerationIntegrationTests(unittest.TestCase):
|
||||
def test_validate_generation_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
|
||||
model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
|
||||
encoder_input_str = "Hello world"
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors="np").input_ids
|
||||
|
||||
# typos are quickly detected (the correct argument is `do_sample`)
|
||||
with self.assertRaisesRegex(ValueError, "do_samples"):
|
||||
model.generate(input_ids, do_samples=True)
|
||||
|
||||
# arbitrary arguments that will not be used anywhere are also not accepted
|
||||
with self.assertRaisesRegex(ValueError, "foo"):
|
||||
fake_model_kwargs = {"foo": "bar"}
|
||||
model.generate(input_ids, **fake_model_kwargs)
|
||||
@@ -1,688 +0,0 @@
|
||||
"""
|
||||
Framework agnostic tests for generate()-related methods.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.testing_utils import slow, torch_device
|
||||
|
||||
|
||||
class GenerationIntegrationTestsMixin:
|
||||
# To be populated by the child classes
|
||||
framework_dependent_parameters = {
|
||||
"AutoModelForCausalLM": None,
|
||||
"AutoModelForSpeechSeq2Seq": None,
|
||||
"AutoModelForSeq2SeqLM": None,
|
||||
"AutoModelForVision2Seq": None,
|
||||
"LogitsProcessorList": None,
|
||||
"MinLengthLogitsProcessor": None,
|
||||
"create_tensor_fn": None,
|
||||
"floats_tensor": None,
|
||||
"return_tensors": None,
|
||||
"set_seed": None,
|
||||
}
|
||||
|
||||
def test_validate_generation_inputs(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
encoder_input_str = "Hello world"
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors=return_tensors).input_ids
|
||||
|
||||
# typos are quickly detected (the correct argument is `do_sample`)
|
||||
with self.assertRaisesRegex(ValueError, "do_samples"):
|
||||
model.generate(input_ids, do_samples=True)
|
||||
|
||||
# arbitrary arguments that will not be used anywhere are also not accepted
|
||||
with self.assertRaisesRegex(ValueError, "foo"):
|
||||
fake_model_kwargs = {"foo": "bar"}
|
||||
model.generate(input_ids, **fake_model_kwargs)
|
||||
|
||||
# however, valid model_kwargs are accepted
|
||||
valid_model_kwargs = {"attention_mask": create_tensor_fn(np.zeros_like(input_ids))}
|
||||
model.generate(input_ids, **valid_model_kwargs)
|
||||
|
||||
def test_custom_logits_processor(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
logits_processor_list_cls = self.framework_dependent_parameters["LogitsProcessorList"]
|
||||
min_length_logits_processor_cls = self.framework_dependent_parameters["MinLengthLogitsProcessor"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", min_length=1)
|
||||
input_ids = bart_tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
|
||||
logits_processor = logits_processor_list_cls()
|
||||
logits_processor.append(min_length_logits_processor_cls(min_length=10, eos_token_id=0))
|
||||
# it should not be allowed to both define `min_length` via config and `logits_processor` list
|
||||
with self.assertRaises(ValueError):
|
||||
bart_model.generate(input_ids, logits_processor=logits_processor)
|
||||
|
||||
bart_model.config.min_length = None
|
||||
bart_model.generate(input_ids, logits_processor=logits_processor)
|
||||
|
||||
def test_max_new_tokens_encoder_decoder(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
bart_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
input_ids = bart_tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
if is_pt:
|
||||
bart_model = bart_model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 29])
|
||||
|
||||
max_new_tokens = 3
|
||||
bart_model.config.max_length = 20
|
||||
bart_model.config.eos_token_id = None
|
||||
|
||||
# Encoder decoder call
|
||||
outputs = bart_model.generate(input_ids, max_new_tokens=max_new_tokens)
|
||||
# 1 BOS + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 4])
|
||||
|
||||
# Decoder only call
|
||||
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
|
||||
# 1 BOS + 29 (input length) + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 33])
|
||||
|
||||
# Encoder decoder call > 20
|
||||
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
|
||||
|
||||
# 1 BOS + 20 + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
def test_max_new_tokens_decoder_only(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
article = """Justin Timberlake."""
|
||||
gpt2_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
|
||||
gpt2_model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
input_ids = gpt2_tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
if is_pt:
|
||||
gpt2_model = gpt2_model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
self.assertEqual(list(input_ids.shape), [1, 9])
|
||||
|
||||
max_new_tokens = 3
|
||||
gpt2_model.config.max_length = 20
|
||||
|
||||
# call < 20
|
||||
outputs = gpt2_model.generate(input_ids, max_new_tokens=max_new_tokens)
|
||||
|
||||
# 9 input_ids + 3 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 12])
|
||||
|
||||
# call > 20
|
||||
outputs = gpt2_model.generate(max_new_tokens=max_new_tokens + 20)
|
||||
|
||||
# 1 BOS token + 23 new tokens
|
||||
self.assertEqual(list(outputs.shape), [1, 24])
|
||||
|
||||
def test_encoder_decoder_generate_with_inputs_embeds(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
|
||||
model.config.eos_token_id = None
|
||||
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
|
||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||
|
||||
output_sequences = model.generate(inputs_embeds=inputs_embeds)
|
||||
|
||||
# make sure model generated correctly until `max_length`
|
||||
self.assertEqual(output_sequences.shape, (1, 5))
|
||||
|
||||
def test_transition_scores_greedy_search(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
articles = ["Justin Timberlake", "Michael Phelps"]
|
||||
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = model_cls.from_pretrained("distilbert/distilgpt2")
|
||||
model.generation_config.eos_token_id = None
|
||||
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=5,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores)
|
||||
if is_pt:
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
|
||||
expected_scores = np.array(
|
||||
[
|
||||
[-57.8844, -60.45698, -70.16364, -65.50791, -66.35648],
|
||||
[-54.417572, -60.216614, -62.661243, -58.621933, -58.298683],
|
||||
]
|
||||
)
|
||||
self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_greedy_search_normalized(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
articles = ["Justin Timberlake", "Michael Phelps"]
|
||||
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = model_cls.from_pretrained("distilbert/distilgpt2")
|
||||
model.generation_config.eos_token_id = None
|
||||
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=5,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
|
||||
if is_pt:
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
|
||||
expected_scores = np.array(
|
||||
[
|
||||
[-2.538938, -2.2694316, -2.1580915, -1.572299, -2.6719835],
|
||||
[-1.8826028, -2.2461371, -1.7556462, -2.9644494, -1.7996008],
|
||||
]
|
||||
)
|
||||
self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_beam_search_encoder_decoder(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
model = model_cls.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bart",
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
eos_token_id=None,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(input_ids=input_ids)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
|
||||
if is_pt:
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
model = model_cls.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bart",
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(input_ids=input_ids)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
|
||||
if is_pt:
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_beam_search_decoder_only(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
articles = [
|
||||
"Justin Timberlake",
|
||||
"Michael Phelps",
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = model_cls.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-gpt2",
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
eos_token_id=None,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(input_ids=input_ids)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
|
||||
if is_pt:
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_beam_sample_encoder_decoder(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
model = model_cls.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bart",
|
||||
do_sample=True,
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
eos_token_id=None,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(input_ids=input_ids)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
|
||||
if is_pt:
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_transition_scores_early_stopping(self):
|
||||
# This is an aggressive test that makes sure that `beam_search's`
|
||||
# transition scores are computed correctly for varying `num_return_sequences`, `num_beams` and `batch_size > 1`
|
||||
# 2 x input_ids for "question: How are you? \n context: I had a long day, "
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
input_ids = create_tensor_fn(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]])
|
||||
model = model_cls.from_pretrained("google-t5/t5-small")
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
max_length=10,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
forced_eos_token_id=model.config.eos_token_id,
|
||||
num_beams=4,
|
||||
do_sample=False,
|
||||
num_return_sequences=3,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
|
||||
transition_scores = model.compute_transition_scores(
|
||||
sequences=outputs.sequences, scores=outputs.scores, beam_indices=outputs.beam_indices
|
||||
)
|
||||
if is_pt:
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores))
|
||||
|
||||
def test_encoder_decoder_generate_attention_mask(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
# need extreme generation values here to force this test
|
||||
# to fail when `attention_mask` is not correctly treated in generate
|
||||
model = model_cls.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bart", max_length=50, num_beams=5, num_return_sequences=5
|
||||
)
|
||||
model.config.eos_token_id = None
|
||||
input_ids = tokenizer(articles[0], return_tensors=return_tensors).input_ids
|
||||
input_ids_batched = tokenizer(articles, padding=True, return_tensors=return_tensors).input_ids
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
input_ids_batched = input_ids_batched.to(torch_device)
|
||||
|
||||
output_sequences_batched = model.generate(
|
||||
input_ids=input_ids_batched, return_dict_in_generate=True, output_scores=True
|
||||
)
|
||||
output_sequences = model.generate(input_ids=input_ids, return_dict_in_generate=True, output_scores=True)
|
||||
|
||||
batched_out = output_sequences_batched.sequences_scores
|
||||
out = output_sequences.sequences_scores
|
||||
if is_pt:
|
||||
batched_out = batched_out.cpu().numpy()
|
||||
out = out.cpu().numpy()
|
||||
|
||||
diff = np.abs(np.sum(batched_out[:5]) - np.sum(out))
|
||||
self.assertTrue(diff < 1e-4)
|
||||
|
||||
def test_generate_input_ids_as_kwarg(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
article = """I need input_ids to generate"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=15)
|
||||
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output_sequences_kwargs = model.generate(input_ids=input_ids)
|
||||
output_sequences = model.generate(input_ids)
|
||||
if is_pt:
|
||||
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
|
||||
output_sequences = output_sequences.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
|
||||
self.assertEqual(output_sequences.shape, (1, 15))
|
||||
|
||||
def test_generate_input_ids_as_encoder_kwarg(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
|
||||
model.config.eos_token_id = None
|
||||
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output_sequences_kwargs = model.generate(input_ids=input_ids)
|
||||
output_sequences = model.generate(input_ids)
|
||||
if is_pt:
|
||||
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
|
||||
output_sequences = output_sequences.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
|
||||
self.assertEqual(output_sequences.shape, (1, 5))
|
||||
|
||||
def test_generate_inputs_and_encoder_kwargs(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
|
||||
article = """I need input_ids to generate"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10)
|
||||
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, input_ids=input_ids)
|
||||
|
||||
def test_generate_too_many_encoder_kwargs(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
|
||||
article = """I need input_ids to generate"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=10)
|
||||
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
|
||||
|
||||
def test_generate_input_features_as_encoder_kwarg(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"]
|
||||
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
input_features = floats_tensor((3, 80, 60))
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-WhisperForConditionalGeneration")
|
||||
if is_pt:
|
||||
input_features.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
|
||||
output_sequences_kwargs = model.generate(input_features=input_features, max_length=5)
|
||||
output_sequences = model.generate(input_features, max_length=5)
|
||||
if is_pt:
|
||||
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
|
||||
output_sequences = output_sequences.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
|
||||
self.assertEqual(output_sequences.shape, (3, 5))
|
||||
|
||||
def test_generate_pixel_values_as_encoder_kwarg(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"]
|
||||
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
pixel_values = floats_tensor((2, 3, 30, 30))
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
|
||||
model.generation_config.eos_token_id = None
|
||||
if is_pt:
|
||||
pixel_values = pixel_values.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
|
||||
output_sequences_kwargs = model.generate(pixel_values=pixel_values, max_length=5)
|
||||
output_sequences = model.generate(pixel_values, max_length=5)
|
||||
if is_pt:
|
||||
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
|
||||
output_sequences = output_sequences.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
|
||||
self.assertEqual(output_sequences.shape, (2, 5))
|
||||
|
||||
def test_generate_encoder_outputs_attention_mask(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForSpeechSeq2Seq"]
|
||||
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
|
||||
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
input_features = floats_tensor((3, 80, 60))
|
||||
attention_mask = create_tensor_fn(np.ones(input_features.shape))
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-WhisperForConditionalGeneration")
|
||||
if is_pt:
|
||||
input_features = input_features.to(torch_device)
|
||||
attention_mask = attention_mask.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
|
||||
encoder = model.get_encoder()
|
||||
encoder_outputs = encoder(input_features)
|
||||
|
||||
output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs)
|
||||
output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask)
|
||||
if is_pt:
|
||||
output_sequences_no_mask = output_sequences_no_mask.cpu().numpy()
|
||||
output_sequences_with_mask = output_sequences_with_mask.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences_no_mask, output_sequences_with_mask))
|
||||
|
||||
def test_eos_token_id_int_and_list_greedy_search(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
"num_beams": 1,
|
||||
}
|
||||
expectation = 13
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors=return_tensors)
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
tokens = tokens.to(torch_device)
|
||||
|
||||
eos_token_id = 873
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
eos_token_id = [873, 198]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_eos_token_id_int_and_list_contrastive_search(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
"num_beams": 1,
|
||||
"penalty_alpha": 0.6,
|
||||
"top_k": 4,
|
||||
}
|
||||
expectation = 17
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors=return_tensors)
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
tokens = tokens.to(torch_device)
|
||||
|
||||
eos_token_id = 225
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
eos_token_id = [225, 198]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_eos_token_id_int_and_list_beam_search(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
|
||||
return_tensors = self.framework_dependent_parameters["return_tensors"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
"num_beams": 3,
|
||||
}
|
||||
expectation = 13
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors=return_tensors)
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
if is_pt:
|
||||
model = model.to(torch_device)
|
||||
tokens = tokens.to(torch_device)
|
||||
|
||||
eos_token_id = 873
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
unpadded_correct_condition = expectation == len(generated_tokens[0])
|
||||
padded_correct_condition = expectation < len(generated_tokens[0]) and all(
|
||||
token == model.config.pad_token_id for token in generated_tokens[0][expectation:]
|
||||
)
|
||||
self.assertTrue(unpadded_correct_condition or padded_correct_condition)
|
||||
|
||||
eos_token_id = [873, 198]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
unpadded_correct_condition = expectation == len(generated_tokens[0])
|
||||
padded_correct_condition = expectation < len(generated_tokens[0]) and all(
|
||||
token == model.config.pad_token_id for token in generated_tokens[0][expectation:]
|
||||
)
|
||||
self.assertTrue(unpadded_correct_condition or padded_correct_condition)
|
||||
|
||||
def test_generate_vision2text_conditioning(self):
|
||||
model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"]
|
||||
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
|
||||
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
|
||||
is_pt = not model_cls.__name__.startswith("TF")
|
||||
|
||||
pixel_values = floats_tensor((2, 3, 30, 30))
|
||||
conditioning_input = create_tensor_fn([[10], [10]]) # this should be the 2nd output token, after the BOS token
|
||||
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
|
||||
if is_pt:
|
||||
pixel_values = pixel_values.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
conditioning_input = conditioning_input.to(torch_device)
|
||||
|
||||
# we can condition on decoder_input_ids (expected decoder input) and input_ids (which we pipe internally as
|
||||
# decoder_input_ids, if the encoder is not a model with text input)
|
||||
output_sequences_decoder_input_ids = model.generate(
|
||||
pixel_values, max_length=5, decoder_input_ids=conditioning_input
|
||||
)
|
||||
output_sequences_input_ids = model.generate(pixel_values, max_length=5, input_ids=conditioning_input)
|
||||
if is_pt:
|
||||
output_sequences_decoder_input_ids = output_sequences_decoder_input_ids.cpu().numpy()
|
||||
output_sequences_input_ids = output_sequences_input_ids.cpu().numpy()
|
||||
conditioning_input = conditioning_input.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
|
||||
@@ -1,487 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a clone of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_tf
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.generation import (
|
||||
TFForcedBOSTokenLogitsProcessor,
|
||||
TFForcedEOSTokenLogitsProcessor,
|
||||
TFForceTokensLogitsProcessor,
|
||||
TFLogitsProcessorList,
|
||||
TFMinLengthLogitsProcessor,
|
||||
TFNoBadWordsLogitsProcessor,
|
||||
TFNoRepeatNGramLogitsProcessor,
|
||||
TFRepetitionPenaltyLogitsProcessor,
|
||||
TFSuppressTokensAtBeginLogitsProcessor,
|
||||
TFSuppressTokensLogitsProcessor,
|
||||
TFTemperatureLogitsWarper,
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
)
|
||||
|
||||
from ..test_modeling_tf_common import ids_tensor
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFLogitsProcessorTest(unittest.TestCase):
|
||||
def _get_uniform_logits(self, batch_size: int, length: int):
|
||||
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
|
||||
return scores
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_min_length_dist_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
|
||||
min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
if use_xla:
|
||||
min_dist_processor = tf.function(min_dist_processor, jit_compile=True)
|
||||
|
||||
# check that min length is applied at length 5
|
||||
cur_len = 5
|
||||
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_before_min_length = min_dist_processor(input_ids, scores, cur_len)
|
||||
self.assertListEqual(scores_before_min_length[:, eos_token_id].numpy().tolist(), 4 * [-float("inf")])
|
||||
|
||||
# check that min length is not applied anymore at length 15
|
||||
cur_len = 15
|
||||
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_before_min_length = min_dist_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy())
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_temperature_dist_warper(self, use_xla):
|
||||
input_ids = None
|
||||
cur_len = None
|
||||
length = 20
|
||||
|
||||
scores = self._get_uniform_logits(batch_size=2, length=length)
|
||||
|
||||
# tweak scores to not be uniform anymore
|
||||
scores = scores.numpy()
|
||||
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
|
||||
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
|
||||
scores = tf.convert_to_tensor(scores)
|
||||
|
||||
# compute softmax
|
||||
probs = tf.nn.softmax(scores, axis=-1)
|
||||
|
||||
temp_dist_warper_sharper = TFTemperatureLogitsWarper(temperature=0.5)
|
||||
temp_dist_warper_smoother = TFTemperatureLogitsWarper(temperature=1.3)
|
||||
if use_xla:
|
||||
temp_dist_warper_sharper = tf.function(temp_dist_warper_sharper, jit_compile=True)
|
||||
temp_dist_warper_smoother = tf.function(temp_dist_warper_smoother, jit_compile=True)
|
||||
|
||||
warped_prob_sharp = tf.nn.softmax(temp_dist_warper_sharper(input_ids, tf.identity(scores), cur_len), axis=-1)
|
||||
warped_prob_smooth = tf.nn.softmax(temp_dist_warper_smoother(input_ids, tf.identity(scores), cur_len), axis=-1)
|
||||
|
||||
# uniform distribution stays uniform
|
||||
tf.debugging.assert_near(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)
|
||||
tf.debugging.assert_near(probs[0, :], warped_prob_smooth[0, :], atol=1e-3)
|
||||
|
||||
# sharp peaks get higher, valleys get lower
|
||||
self.assertLess(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_sharp[1, :]))
|
||||
self.assertGreater(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_sharp[1, :]))
|
||||
|
||||
# smooth peaks get lower, valleys get higher
|
||||
self.assertGreater(tf.math.reduce_max(probs[1, :]), tf.math.reduce_max(warped_prob_smooth[1, :]))
|
||||
self.assertLess(tf.math.reduce_min(probs[1, :]), tf.math.reduce_min(warped_prob_smooth[1, :]))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_repetition_penalty_dist_process(self, use_xla):
|
||||
vocab_size = 10
|
||||
cur_len = 2
|
||||
|
||||
input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
|
||||
self.assertEqual(cur_len, input_ids.shape[1])
|
||||
|
||||
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
|
||||
|
||||
mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
|
||||
scores = tf.where(mask, -1 / vocab_size, scores)
|
||||
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
|
||||
scores = tf.where(mask, 4 / vocab_size, scores)
|
||||
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
if use_xla:
|
||||
rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)
|
||||
|
||||
scores = rep_penalty_proc(input_ids, tf.identity(scores), cur_len)
|
||||
|
||||
# check that values were correctly changed (negative scores for used tokens should increase, others
|
||||
# should decrease)
|
||||
self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
|
||||
self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
|
||||
self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size)) # unused tokens should see no change
|
||||
|
||||
self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
|
||||
self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
|
||||
self.assertAlmostEqual(scores[0, 2].numpy(), (1 / vocab_size)) # unused tokens should see no change
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_top_k_dist_warper(self, use_xla):
|
||||
input_ids = None
|
||||
cur_len = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create ramp distribution
|
||||
ramp_logits = np.broadcast_to(np.arange(vocab_size, dtype=np.float32), (batch_size, vocab_size)).copy()
|
||||
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
|
||||
|
||||
top_k_warp = TFTopKLogitsWarper(3)
|
||||
if use_xla:
|
||||
top_k_warp = tf.function(top_k_warp, jit_compile=True)
|
||||
|
||||
scores = top_k_warp(input_ids, ramp_logits, cur_len)
|
||||
|
||||
# check that correct tokens are filtered
|
||||
self.assertListEqual(tf.math.is_inf(scores[0]).numpy().tolist(), 7 * [True] + 3 * [False])
|
||||
self.assertListEqual(tf.math.is_inf(scores[1]).numpy().tolist(), 2 * [True] + 3 * [False] + 5 * [True])
|
||||
|
||||
# check special cases
|
||||
length = 5
|
||||
|
||||
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
|
||||
top_k_warp_safety_check = TFTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
|
||||
if use_xla:
|
||||
top_k_warp_safety_check = tf.function(top_k_warp_safety_check, jit_compile=True)
|
||||
|
||||
scores = top_k_warp_safety_check(input_ids, logits, cur_len)
|
||||
# uniform dist is not changed
|
||||
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [0, 0])
|
||||
|
||||
ramp_logits = np.broadcast_to(np.arange(length, dtype=np.float32), (batch_size, length)).copy()
|
||||
scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len)
|
||||
|
||||
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
||||
self.assertListEqual(tf.math.reduce_sum(tf.where(scores == 0.0, 1, 0), axis=-1).numpy().tolist(), [2, 2])
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_top_p_dist_warper(self, use_xla):
|
||||
input_ids = None
|
||||
cur_len = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create distribution and take log (inverse to Softmax as taken in TFTopPLogitsWarper)
|
||||
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], dtype=np.float32))
|
||||
|
||||
# top_p should have been 0.8 to test the edge case of top_p being exactly equal to sum of some token prob
|
||||
# However, due to the numerical instability of softmax in TF we choose this as the edge case
|
||||
# top_p as 0.8 passes when use_xla is True and fails when False. Refer PR #18984.
|
||||
top_p_warp = TFTopPLogitsWarper(0.79999995)
|
||||
if use_xla:
|
||||
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||
filtered_dist = tf.exp(top_p_warp(input_ids, dist, cur_len))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= top_p
|
||||
# exp (-inf) => 0
|
||||
EXPECTED_FILTERED_DIST = tf.constant([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], dtype=tf.float32)
|
||||
tf.debugging.assert_near(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)
|
||||
|
||||
# check edge cases with negative and extreme logits
|
||||
ramp_logits = np.broadcast_to(
|
||||
np.arange(vocab_size, dtype=np.float32)[None, :], (batch_size, vocab_size)
|
||||
).copy() - (vocab_size // 2)
|
||||
|
||||
# make ramp_logits more extreme
|
||||
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||
|
||||
# make sure at least 2 tokens are kept
|
||||
top_p_warp = TFTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
||||
if use_xla:
|
||||
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||
filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len)
|
||||
|
||||
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps
|
||||
# 2.
|
||||
self.assertListEqual(
|
||||
tf.math.reduce_sum(tf.where(filtered_dist != 0.0, 1, 0), axis=-1).numpy().tolist(), [3, 2]
|
||||
)
|
||||
|
||||
def test_no_repeat_ngram_dist_processor(self):
|
||||
vocab_size = 3
|
||||
batch_size = 2
|
||||
cur_len = 4
|
||||
|
||||
input_ids = tf.constant([[1, 1, 2, 1], [0, 1, 0, 1]], dtype=tf.int32)
|
||||
self.assertEqual(cur_len, input_ids.shape[1])
|
||||
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
no_repeat_proc_2_gram = TFNoRepeatNGramLogitsProcessor(2)
|
||||
no_repeat_proc_3_gram = TFNoRepeatNGramLogitsProcessor(3)
|
||||
|
||||
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, tf.identity(scores), cur_len)
|
||||
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, tf.identity(scores), cur_len)
|
||||
|
||||
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
|
||||
self.assertListEqual(
|
||||
tf.math.is_inf(filtered_scores_2_gram).numpy().tolist(), [[False, True, True], [True, False, False]]
|
||||
)
|
||||
|
||||
# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
|
||||
self.assertListEqual(
|
||||
tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]]
|
||||
)
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_no_bad_words_dist_processor(self, use_xla):
|
||||
vocab_size = 5
|
||||
batch_size = 2
|
||||
eos_token_id = 4
|
||||
cur_len = 4
|
||||
|
||||
input_ids = tf.constant([[0, 1, 3, 1], [0, 1, 0, 1]], dtype=tf.int32)
|
||||
self.assertEqual(cur_len, input_ids.shape[1])
|
||||
|
||||
bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
|
||||
if use_xla:
|
||||
no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
|
||||
|
||||
filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores), cur_len)
|
||||
|
||||
# batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden
|
||||
# batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden
|
||||
self.assertListEqual(
|
||||
tf.math.is_inf(filtered_scores).numpy().tolist(),
|
||||
[[True, True, False, True, True], [True, True, True, False, True]],
|
||||
)
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_forced_bos_token_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
bos_token_id = 0
|
||||
|
||||
logits_processor = TFForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# check that all scores are -inf except the bos_token_id score
|
||||
cur_len = 1
|
||||
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertTrue(
|
||||
tf.math.reduce_all(tf.math.is_inf(scores[:, bos_token_id + 1 :]) & (scores[:, bos_token_id + 1 :] < 0))
|
||||
)
|
||||
self.assertListEqual(scores[:, bos_token_id].numpy().tolist(), 4 * [0]) # score for bos_token_id shold be zero
|
||||
|
||||
# check that bos_token_id is not forced if current length is greater than 1
|
||||
cur_len = 4
|
||||
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_forced_eos_token_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
max_length = 5
|
||||
|
||||
logits_processor = TFForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# check that all scores are -inf except the eos_token_id when max_length-1 is reached
|
||||
cur_len = 4
|
||||
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertTrue(
|
||||
tf.math.reduce_all(tf.math.is_inf(scores[:, eos_token_id + 1 :]) & (scores[:, eos_token_id + 1 :] < 0))
|
||||
)
|
||||
self.assertListEqual(
|
||||
scores[:, eos_token_id].numpy().tolist(), 4 * [0]
|
||||
) # score for eos_token_id should be zero
|
||||
|
||||
# check that eos_token_id is not forced if max_length-1 is not reached
|
||||
cur_len = 3
|
||||
input_ids = ids_tensor((batch_size, cur_len), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_suppress_tokens_at_begin_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
|
||||
begin_suppress_tokens = [1, 2, 3]
|
||||
begin_index = 5
|
||||
|
||||
logits_processor = TFSuppressTokensAtBeginLogitsProcessor(
|
||||
begin_suppress_tokens=begin_suppress_tokens, begin_index=begin_index
|
||||
)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# Check that no scores are suppressed if begin_index is not reached
|
||||
cur_len = 4
|
||||
input_ids = tf.convert_to_tensor([[11, 17, 15, 8], [14, 0, 19, 5], [13, 11, 18, 19], [11, 12, 16, 15]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
# Check that scores are suppressed if begin_index is reached
|
||||
cur_len = 5
|
||||
input_ids = tf.convert_to_tensor([[5, 5, 5, 0, 17], [18, 1, 9, 14, 17], [18, 6, 8, 15, 19], [8, 12, 17, 1, 2]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertTrue(tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, begin_suppress_tokens, axis=1))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_suppress_tokens_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
|
||||
suppress_tokens = [1, 3, 5]
|
||||
keep_tokens = [i for i in range(vocab_size) if i not in suppress_tokens]
|
||||
|
||||
logits_processor = TFSuppressTokensLogitsProcessor(suppress_tokens=suppress_tokens)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# Check that suppress_tokens are suppressed and others are not
|
||||
cur_len = 5
|
||||
input_ids = tf.convert_to_tensor([[0, 10, 19, 6, 3], [17, 4, 8, 17, 2], [7, 1, 11, 6, 15], [5, 8, 13, 16, 0]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertTrue(tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, suppress_tokens, axis=1))))
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf(tf.gather(scores, keep_tokens, axis=1))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_force_tokens_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
|
||||
force_token_map = {1: 2, 3: 2}
|
||||
|
||||
logits_processor = TFForceTokensLogitsProcessor(force_token_map=force_token_map)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# check that if the cur_len is contained in the force_token_map, the logits are the same
|
||||
# for all tokens except the one the force_token_map points to
|
||||
cur_len = 1
|
||||
input_ids = tf.convert_to_tensor([[11], [7], [5], [15]])
|
||||
ids_tensor((batch_size, cur_len), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
tf.debugging.assert_near(tf.gather(scores, [force_token_map[cur_len]], axis=1), 0.0)
|
||||
|
||||
non_forced_inds = [i for i in range(vocab_size) if i != force_token_map[cur_len]]
|
||||
self.assertTrue(
|
||||
tf.math.reduce_all(
|
||||
tf.experimental.numpy.isclose(
|
||||
tf.gather(scores, [non_forced_inds], axis=1),
|
||||
tf.constant(scores.dtype.min),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# check that if the cur_len is not contained in the force_token_map, the logits are not modified
|
||||
cur_len = 2
|
||||
input_ids = tf.convert_to_tensor([[2, 19], [19, 15], [4, 9], [7, 6]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_processor_list(self, use_xla):
|
||||
# TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it gets compatible with XLA
|
||||
batch_size = 4
|
||||
cur_len = 10
|
||||
vocab_size = 15
|
||||
eos_token_id = 0
|
||||
|
||||
# dummy input_ids and scores
|
||||
input_ids = ids_tensor((batch_size, cur_len), vocab_size)
|
||||
input_ids_comp = tf.identity(input_ids)
|
||||
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_comp = tf.identity(scores)
|
||||
|
||||
# instantiate all dist processors
|
||||
min_dist_proc = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
temp_dist_warp = TFTemperatureLogitsWarper(temperature=0.5)
|
||||
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
top_k_warp = TFTopKLogitsWarper(3)
|
||||
top_p_warp = TFTopPLogitsWarper(0.8)
|
||||
# no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
|
||||
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
|
||||
if use_xla:
|
||||
min_dist_proc = tf.function(min_dist_proc, jit_compile=True)
|
||||
temp_dist_warp = tf.function(temp_dist_warp, jit_compile=True)
|
||||
rep_penalty_proc = tf.function(rep_penalty_proc, jit_compile=True)
|
||||
top_k_warp = tf.function(top_k_warp, jit_compile=True)
|
||||
top_p_warp = tf.function(top_p_warp, jit_compile=True)
|
||||
# no_repeat_proc = tf.function(no_repeat_proc, jit_compile=True)
|
||||
no_bad_words_dist_proc = tf.function(no_bad_words_dist_proc, jit_compile=True)
|
||||
|
||||
# no processor list
|
||||
scores = min_dist_proc(input_ids, scores, cur_len)
|
||||
scores = temp_dist_warp(input_ids, scores, cur_len)
|
||||
scores = rep_penalty_proc(input_ids, scores, cur_len)
|
||||
scores = top_k_warp(input_ids, scores, cur_len)
|
||||
scores = top_p_warp(input_ids, scores, cur_len)
|
||||
# scores = no_repeat_proc(input_ids, scores, cur_len)
|
||||
scores = no_bad_words_dist_proc(input_ids, scores, cur_len)
|
||||
|
||||
# with processor list
|
||||
processor = TFLogitsProcessorList(
|
||||
[
|
||||
min_dist_proc,
|
||||
temp_dist_warp,
|
||||
rep_penalty_proc,
|
||||
top_k_warp,
|
||||
top_p_warp,
|
||||
# no_repeat_proc,
|
||||
no_bad_words_dist_proc,
|
||||
]
|
||||
)
|
||||
scores_comp = processor(input_ids, scores_comp, cur_len)
|
||||
|
||||
# remove inf
|
||||
scores = tf.where(tf.math.is_inf(scores), -1e9, scores)
|
||||
scores_comp = tf.where(tf.math.is_inf(scores_comp), -1e9, scores_comp)
|
||||
|
||||
# scores should be equal
|
||||
tf.debugging.assert_near(scores, scores_comp, atol=1e-3)
|
||||
|
||||
# input_ids should never be changed
|
||||
self.assertListEqual(input_ids.numpy().tolist(), input_ids_comp.numpy().tolist())
|
||||
@@ -1,245 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a clone of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from transformers import is_tensorflow_text_available, is_tf_available
|
||||
from transformers.testing_utils import require_tensorflow_text, require_tf, slow
|
||||
|
||||
from ..test_modeling_tf_common import floats_tensor
|
||||
from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
TFAutoModelForCausalLM,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFAutoModelForSpeechSeq2Seq,
|
||||
TFAutoModelForVision2Seq,
|
||||
TFBartForConditionalGeneration,
|
||||
TFLogitsProcessorList,
|
||||
TFMinLengthLogitsProcessor,
|
||||
)
|
||||
from transformers.modeling_tf_utils import keras
|
||||
|
||||
if is_tensorflow_text_available():
|
||||
import tensorflow_text as text
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
|
||||
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
|
||||
if is_tf_available():
|
||||
framework_dependent_parameters = {
|
||||
"AutoModelForCausalLM": TFAutoModelForCausalLM,
|
||||
"AutoModelForSpeechSeq2Seq": TFAutoModelForSpeechSeq2Seq,
|
||||
"AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
|
||||
"AutoModelForVision2Seq": TFAutoModelForVision2Seq,
|
||||
"LogitsProcessorList": TFLogitsProcessorList,
|
||||
"MinLengthLogitsProcessor": TFMinLengthLogitsProcessor,
|
||||
"create_tensor_fn": tf.convert_to_tensor,
|
||||
"floats_tensor": floats_tensor,
|
||||
"return_tensors": "tf",
|
||||
}
|
||||
|
||||
@slow
|
||||
def test_generate_tf_function_export_fixed_input_length(self):
|
||||
# TF-only test: tf.saved_model export
|
||||
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
input_length = 2
|
||||
max_new_tokens = 2
|
||||
|
||||
class DummyModel(tf.Module):
|
||||
def __init__(self, model):
|
||||
super(DummyModel, self).__init__()
|
||||
self.model = model
|
||||
|
||||
@tf.function(
|
||||
input_signature=(
|
||||
tf.TensorSpec((None, input_length), tf.int32, name="input_ids"),
|
||||
tf.TensorSpec((None, input_length), tf.int32, name="attention_mask"),
|
||||
),
|
||||
jit_compile=True,
|
||||
)
|
||||
def serving(self, input_ids, attention_mask):
|
||||
outputs = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
return {"sequences": outputs["sequences"]}
|
||||
|
||||
dummy_input_ids = [[2, 0], [102, 103]]
|
||||
dummy_attention_masks = [[1, 0], [1, 1]]
|
||||
dummy_model = DummyModel(model=test_model)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
|
||||
serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
|
||||
for batch_size in range(1, len(dummy_input_ids) + 1):
|
||||
inputs = {
|
||||
"input_ids": tf.constant(dummy_input_ids[:batch_size]),
|
||||
"attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
|
||||
}
|
||||
tf_func_outputs = serving_func(**inputs)["sequences"]
|
||||
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
||||
|
||||
@slow
|
||||
def test_generate_tf_function_export_fixed_batch_size(self):
|
||||
# TF-only test: tf.saved_model export
|
||||
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
batch_size = 1
|
||||
max_new_tokens = 2
|
||||
|
||||
class DummyModel(tf.Module):
|
||||
def __init__(self, model):
|
||||
super(DummyModel, self).__init__()
|
||||
self.model = model
|
||||
|
||||
@tf.function(
|
||||
input_signature=(
|
||||
tf.TensorSpec((batch_size, None), tf.int32, name="input_ids"),
|
||||
tf.TensorSpec((batch_size, None), tf.int32, name="attention_mask"),
|
||||
),
|
||||
jit_compile=True,
|
||||
)
|
||||
def serving(self, input_ids, attention_mask):
|
||||
outputs = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
return {"sequences": outputs["sequences"]}
|
||||
|
||||
dummy_input_ids = [[2], [102, 103]]
|
||||
dummy_attention_masks = [[1], [1, 1]]
|
||||
dummy_model = DummyModel(model=test_model)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
|
||||
serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
|
||||
for input_row in range(len(dummy_input_ids)):
|
||||
inputs = {
|
||||
"input_ids": tf.constant([dummy_input_ids[input_row]]),
|
||||
"attention_mask": tf.constant([dummy_attention_masks[input_row]]),
|
||||
}
|
||||
tf_func_outputs = serving_func(**inputs)["sequences"]
|
||||
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
||||
|
||||
@slow
|
||||
@require_tensorflow_text
|
||||
def test_generate_tf_function_export_with_tf_tokenizer(self):
|
||||
# TF-only test: tf.saved_model export
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# file needed to load the TF tokenizer
|
||||
hf_hub_download(repo_id="google/flan-t5-small", filename="spiece.model", local_dir=tmp_dir)
|
||||
|
||||
class CompleteSentenceTransformer(keras.layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tokenizer = text.SentencepieceTokenizer(
|
||||
model=tf.io.gfile.GFile(os.path.join(tmp_dir, "spiece.model"), "rb").read()
|
||||
)
|
||||
self.model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
def call(self, inputs, *args, **kwargs):
|
||||
tokens = self.tokenizer.tokenize(inputs)
|
||||
input_ids, attention_mask = text.pad_model_inputs(
|
||||
tokens, max_seq_length=64, pad_value=self.model.config.pad_token_id
|
||||
)
|
||||
outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask)
|
||||
return self.tokenizer.detokenize(outputs)
|
||||
|
||||
complete_model = CompleteSentenceTransformer()
|
||||
inputs = keras.layers.Input(shape=(1,), dtype=tf.string, name="inputs")
|
||||
outputs = complete_model(inputs)
|
||||
keras_model = keras.Model(inputs, outputs)
|
||||
keras_model.save(tmp_dir)
|
||||
|
||||
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
|
||||
# Has PT equivalent: this test relies on random sampling
|
||||
generation_kwargs = {
|
||||
"do_sample": True,
|
||||
"num_beams": 1,
|
||||
"top_p": 0.7,
|
||||
"top_k": 10,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
expectation = 14
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors="tf")
|
||||
model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
|
||||
eos_token_id = 638
|
||||
# forces the generation to happen on CPU, to avoid GPU-related quirks
|
||||
with tf.device(":/CPU:0"):
|
||||
tf.random.set_seed(0)
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
eos_token_id = [638, 198]
|
||||
with tf.device(":/CPU:0"):
|
||||
tf.random.set_seed(0)
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_model_kwarg_encoder_signature_filtering(self):
|
||||
# Has PT equivalent: ample use of framework-specific code
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
article = """Hugging Face is a technology company based in New York and Paris."""
|
||||
input_ids = bart_tokenizer(article, return_tensors="tf").input_ids
|
||||
bart_model = TFBartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
output = bart_model.generate(input_ids).numpy()
|
||||
|
||||
# Let's create a fake model that has a different signature. In particular, this fake model accepts "foo" as an
|
||||
# argument. Because "foo" is not in the encoder signature and doesn't start with "decoder_", it will be part of
|
||||
# the encoder kwargs prior to signature filtering, which would lead to an exception. But filtering kicks in and
|
||||
# saves the day.
|
||||
class FakeBart(TFBartForConditionalGeneration):
|
||||
def call(self, input_ids, foo=None, **kwargs):
|
||||
return super().call(input_ids, **kwargs)
|
||||
|
||||
bart_model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
fake_output = bart_model.generate(input_ids, foo="bar").numpy()
|
||||
self.assertTrue(np.array_equal(output, fake_output))
|
||||
|
||||
# Encoder signature filtering only kicks in if it doesn't accept wildcard kwargs. The following test will fail
|
||||
# because it doesn't do signature filtering.
|
||||
class FakeEncoder(bart_model.model.encoder.__class__):
|
||||
def call(self, input_ids, **kwargs):
|
||||
return super().call(input_ids, **kwargs)
|
||||
|
||||
fake_encoder = FakeEncoder(bart_model.config, bart_model.model.shared)
|
||||
bart_model.model.encoder = fake_encoder
|
||||
|
||||
# Normal generation still works (the output will be different because the encoder weights are different)
|
||||
fake_output = bart_model.generate(input_ids).numpy()
|
||||
with self.assertRaises(ValueError):
|
||||
# FakeEncoder.call() accepts **kwargs -> no filtering -> value error due to unexpected input "foo"
|
||||
bart_model.generate(input_ids, foo="bar")
|
||||
@@ -49,7 +49,6 @@ from transformers.testing_utils import (
|
||||
from transformers.utils import is_ipex_available
|
||||
|
||||
from ..test_modeling_common import floats_tensor, ids_tensor
|
||||
from .test_framework_agnostic import GenerationIntegrationTestsMixin
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -2783,24 +2782,9 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
@pytest.mark.generate
|
||||
@require_torch
|
||||
class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
|
||||
# setting framework_dependent_parameters needs to be gated, just like its contents' imports
|
||||
if is_torch_available():
|
||||
framework_dependent_parameters = {
|
||||
"AutoModelForCausalLM": AutoModelForCausalLM,
|
||||
"AutoModelForSpeechSeq2Seq": AutoModelForSpeechSeq2Seq,
|
||||
"AutoModelForSeq2SeqLM": AutoModelForSeq2SeqLM,
|
||||
"AutoModelForVision2Seq": AutoModelForVision2Seq,
|
||||
"LogitsProcessorList": LogitsProcessorList,
|
||||
"MinLengthLogitsProcessor": MinLengthLogitsProcessor,
|
||||
"create_tensor_fn": torch.tensor,
|
||||
"floats_tensor": floats_tensor,
|
||||
"return_tensors": "pt",
|
||||
}
|
||||
|
||||
class GenerationIntegrationTests(unittest.TestCase):
|
||||
@slow
|
||||
def test_diverse_beam_search(self):
|
||||
# PT-only test: TF doesn't have a diverse beam search implementation
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.
|
||||
The celebrity couple announced the arrival of their son, Silas Randall Timberlake, in statements to People.
|
||||
"Silas was the middle name of Timberlake's maternal grandfather Bill Bomar, who died in 2012, while Randall is the musician's own middle name, as well as his father's first," People reports.
|
||||
@@ -2834,7 +2818,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
|
||||
def test_max_length_if_input_embeds(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = "Today a dragon flew over Paris."
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
@@ -2848,7 +2831,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
|
||||
|
||||
def test_min_length_if_input_embeds(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = "Today a dragon flew over Paris."
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
@@ -2862,7 +2844,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
|
||||
|
||||
def test_custom_stopping_criteria_overload_error(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
@@ -2876,7 +2857,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
bart_model.generate(input_ids, stopping_criteria=stopping_criteria, max_length=32)
|
||||
|
||||
def test_custom_stopping_criteria(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
@@ -2900,7 +2880,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
# TODO (joao): replace `stop_sequence` in the pipeline by the more recent `generate` functionality
|
||||
def test_stop_sequence_stopping_criteria(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
prompt = """Hello I believe in"""
|
||||
generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
|
||||
output = generator(prompt)
|
||||
@@ -2913,7 +2892,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertEqual(output, [{"generated_text": "Hello I believe in we"}])
|
||||
|
||||
def test_generate_non_nlp_input_ids_as_kwarg(self):
|
||||
# PT-only test: AFAIK there's no non-NLP model architecture in TF that supports `input_ids` as its only input
|
||||
model = ImageGPTForCausalImageModeling.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-imagegpt", max_length=10
|
||||
).to(torch_device)
|
||||
@@ -2926,7 +2904,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertEqual(output_sequences.shape, (3, 10))
|
||||
|
||||
def test_generate_input_values_as_encoder_kwarg(self):
|
||||
# PT-only test: AFAIK there's no generate-capable architecture in TF that supports `input_values` as its input
|
||||
input_values = floats_tensor((2, 250))
|
||||
model = SpeechEncoderDecoderModel.from_pretrained("hf-internal-testing/tiny-random-speech-encoder-decoder")
|
||||
model = model.to(torch_device)
|
||||
@@ -2937,7 +2914,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertEqual(output_sequences.shape, (2, 5))
|
||||
|
||||
def test_transition_scores_group_beam_search_encoder_decoder(self):
|
||||
# PT-only test: TF doesn't have group beam search
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
@@ -3067,7 +3043,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# PT-only test: TF doesn't have a BeamSearchScorer
|
||||
# exactly the example provided in the docstrings of beam search, which previously
|
||||
# failed after directly copying from it. Refer to PR #15555
|
||||
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
|
||||
@@ -3094,7 +3069,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
@@ -3132,7 +3106,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
@@ -3173,7 +3146,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_mixed_mixin(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
|
||||
|
||||
@@ -3251,7 +3223,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_example_translation_mixin(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
|
||||
|
||||
@@ -3276,7 +3247,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_example_integration(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
|
||||
|
||||
@@ -3345,7 +3315,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertListEqual(out_text, expected_out)
|
||||
|
||||
def test_constrained_beam_search_mixin_type_checks(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
|
||||
|
||||
@@ -3386,7 +3355,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
model.generate(input_ids, force_words_ids=[[[-1]]])
|
||||
|
||||
def test_batched_decoder_start_id(self):
|
||||
# PT-only test: TF doesn't support batched_decoder_start_id
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
@@ -3435,7 +3403,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
outputs = bart_model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))
|
||||
|
||||
def test_contrastive_search_batched(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
|
||||
articles = ["Foo", "Bar Baz"]
|
||||
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
@@ -3461,7 +3428,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(max_score_diff < 1e-5)
|
||||
|
||||
def test_logits_processor_not_inplace(self):
|
||||
# PT-only test: TF fixes were not made
|
||||
article = "Today a dragon flew over Paris."
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
@@ -3572,7 +3538,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertEqual(len(warning_list), 0)
|
||||
|
||||
def test_length_warning_assisted_generation(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
@@ -3604,7 +3569,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertEqual(config.is_assistant, False)
|
||||
|
||||
def test_generated_length_assisted_generation(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
@@ -3639,7 +3603,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(out.shape[-1] <= (input_length + 7))
|
||||
|
||||
def test_model_kwarg_assisted_decoding_decoder_only(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model.generation_config.pad_token_id = tokenizer.eos_token_id
|
||||
@@ -3839,7 +3802,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
@slow
|
||||
@require_torch_multi_gpu
|
||||
def test_assisted_decoding_in_different_gpu(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
|
||||
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||
"cuda:1"
|
||||
@@ -3863,7 +3825,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||
torch_device
|
||||
)
|
||||
@@ -3887,7 +3848,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
|
||||
|
||||
def test_special_tokens_fall_back_to_model_default(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||
torch_device
|
||||
)
|
||||
@@ -4367,6 +4327,416 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
duration = datetime.datetime.now() - start
|
||||
self.assertGreater(duration, datetime.timedelta(seconds=1.5 * MAX_TIME))
|
||||
|
||||
def test_validate_generation_inputs(self):
|
||||
"""Tests validation of inputs to `generate`"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
encoder_input_str = "Hello world"
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
|
||||
|
||||
# typos are quickly detected (the correct argument is `do_sample`)
|
||||
with self.assertRaisesRegex(ValueError, "do_samples"):
|
||||
model.generate(input_ids, do_samples=True)
|
||||
|
||||
# arbitrary arguments that will not be used anywhere are also not accepted
|
||||
with self.assertRaisesRegex(ValueError, "foo"):
|
||||
fake_model_kwargs = {"foo": "bar"}
|
||||
model.generate(input_ids, **fake_model_kwargs)
|
||||
|
||||
# however, valid model_kwargs are accepted
|
||||
valid_model_kwargs = {"attention_mask": torch.tensor(np.zeros_like(input_ids))}
|
||||
model.generate(input_ids, **valid_model_kwargs)
|
||||
|
||||
def test_custom_logits_processor(self):
|
||||
"""Tests that custom logits processors can be used in `generate`, and that redundant arguments are caught."""
|
||||
bart_tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart", min_length=1)
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(MinLengthLogitsProcessor(min_length=10, eos_token_id=0))
|
||||
|
||||
# it should not be allowed to both define `min_length` via config and `logits_processor` list
|
||||
with self.assertRaises(ValueError):
|
||||
bart_model.generate(input_ids, logits_processor=logits_processor, min_length=10)
|
||||
bart_model.generate(input_ids, logits_processor=logits_processor)
|
||||
|
||||
def test_transition_scores_greedy_search(self):
|
||||
"""Test that `compute_transition_scores` is working as expected with gready search"""
|
||||
articles = ["Justin Timberlake", "Michael Phelps"]
|
||||
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
|
||||
model.generation_config.eos_token_id = None
|
||||
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=5,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores)
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
|
||||
expected_scores = np.array(
|
||||
[
|
||||
[-57.8844, -60.45698, -70.16364, -65.50791, -66.35648],
|
||||
[-54.417572, -60.216614, -62.661243, -58.621933, -58.298683],
|
||||
]
|
||||
)
|
||||
self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_greedy_search_normalized(self):
|
||||
"""
|
||||
Test that `compute_transition_scores` is working as expected with gready search, with `normalize_logits=True`
|
||||
"""
|
||||
articles = ["Justin Timberlake", "Michael Phelps"]
|
||||
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
|
||||
model.generation_config.eos_token_id = None
|
||||
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=5,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
|
||||
expected_scores = np.array(
|
||||
[
|
||||
[-2.538938, -2.2694316, -2.1580915, -1.572299, -2.6719835],
|
||||
[-1.8826028, -2.2461371, -1.7556462, -2.9644494, -1.7996008],
|
||||
]
|
||||
)
|
||||
self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_beam_search_encoder_decoder(self):
|
||||
"""
|
||||
Test that `compute_transition_scores` is working as expected with beam search and encoder-decoder models
|
||||
"""
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
eos_token_id=None,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
|
||||
"""
|
||||
Test that `compute_transition_scores` is working as expected with beam search and encoder-decoder models, when
|
||||
an EOS token is defined
|
||||
"""
|
||||
articles = [
|
||||
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
|
||||
"Michael Phelps is arguably the most decorated Olympian of all time.",
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_transition_scores_beam_search_decoder_only(self):
|
||||
"""
|
||||
Test that `compute_transition_scores` is working as expected with beam search and decoder-only models
|
||||
"""
|
||||
articles = [
|
||||
"Justin Timberlake",
|
||||
"Michael Phelps",
|
||||
]
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=10,
|
||||
num_beams=4,
|
||||
num_return_sequences=2,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
eos_token_id=None,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
|
||||
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
@slow
|
||||
def test_transition_scores_early_stopping(self):
|
||||
"""
|
||||
Test that `compute_transition_scores` is working as expected with beam search and early stopping
|
||||
|
||||
This is an aggressive test that makes sure that `beam_search's`
|
||||
transition scores are computed correctly for varying `num_return_sequences`, `num_beams` and `batch_size > 1`
|
||||
2 x input_ids for "question: How are you? \n context: I had a long day, "
|
||||
"""
|
||||
input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]])
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small")
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
max_length=10,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
forced_eos_token_id=model.config.eos_token_id,
|
||||
num_beams=4,
|
||||
do_sample=False,
|
||||
num_return_sequences=3,
|
||||
length_penalty=0.0,
|
||||
)
|
||||
|
||||
transition_scores = model.compute_transition_scores(
|
||||
sequences=outputs.sequences, scores=outputs.scores, beam_indices=outputs.beam_indices
|
||||
)
|
||||
transition_scores = transition_scores.cpu().numpy()
|
||||
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores))
|
||||
|
||||
def test_encoder_decoder_generate_attention_mask(self):
|
||||
"""
|
||||
Test that `generate` automagically creates the correct `attention_mask` for encoder-decoder models (which
|
||||
has a different keyword)
|
||||
"""
|
||||
articles = ["Timberlake", "Jessica Biel, welcome to parenthood among other things"]
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
# need extreme generation values here to force this test
|
||||
# to fail when `attention_mask` is not correctly treated in generate
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bart",
|
||||
)
|
||||
model.config.eos_token_id = None
|
||||
input_ids = tokenizer(articles[0], return_tensors="pt").input_ids
|
||||
input_ids_batched = tokenizer(articles, padding=True, return_tensors="pt").input_ids
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
input_ids_batched = input_ids_batched.to(torch_device)
|
||||
|
||||
generate_kwargs = {
|
||||
"return_dict_in_generate": True,
|
||||
"output_scores": True,
|
||||
"max_length": 50,
|
||||
"num_beams": 5,
|
||||
"num_return_sequences": 5,
|
||||
}
|
||||
|
||||
output_sequences_batched = model.generate(input_ids=input_ids_batched, **generate_kwargs)
|
||||
output_sequences = model.generate(input_ids=input_ids, **generate_kwargs)
|
||||
|
||||
batched_out = output_sequences_batched.sequences_scores
|
||||
out = output_sequences.sequences_scores
|
||||
batched_out = batched_out.cpu().numpy()
|
||||
out = out.cpu().numpy()
|
||||
|
||||
diff = np.abs(np.sum(batched_out[:5]) - np.sum(out))
|
||||
self.assertTrue(diff < 1e-4)
|
||||
|
||||
def test_generate_input_ids_as_kwarg(self):
|
||||
"""Test that `input_ids` work equaly as a positional and keyword argument in decoder-only models"""
|
||||
article = "I need input_ids to generate"
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=15)
|
||||
input_ids = tokenizer(article, return_tensors="pt").input_ids
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output_sequences_kwargs = model.generate(input_ids=input_ids)
|
||||
output_sequences = model.generate(input_ids)
|
||||
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
|
||||
output_sequences = output_sequences.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
|
||||
self.assertEqual(output_sequences.shape, (1, 15))
|
||||
|
||||
def test_generate_input_ids_as_encoder_kwarg(self):
|
||||
"""Test that `input_ids` work equaly as a positional and keyword argument in encoder-decoder models"""
|
||||
article = "Justin Timberlake and Jessica Biel, welcome to parenthood."
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model.config.eos_token_id = None
|
||||
input_ids = tokenizer(article, return_tensors="pt").input_ids
|
||||
model = model.to(torch_device)
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
output_sequences_kwargs = model.generate(input_ids=input_ids, max_length=5)
|
||||
output_sequences = model.generate(input_ids, max_length=5)
|
||||
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
|
||||
output_sequences = output_sequences.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
|
||||
self.assertEqual(output_sequences.shape, (1, 5))
|
||||
|
||||
def test_generate_inputs_and_encoder_kwargs(self):
|
||||
"""
|
||||
Test that an exception is thrown if the main tensor (`input_ids` in LLMs) is passed as both a positional and
|
||||
keyword argument
|
||||
"""
|
||||
article = "I need input_ids to generate"
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10)
|
||||
input_ids = tokenizer(article, return_tensors="pt").input_ids
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids, input_ids=input_ids)
|
||||
|
||||
def test_generate_too_many_encoder_kwargs(self):
|
||||
"""Test that passing redundant inputs results in an exception (`input_ids` and `inputs_embeds` in LLMs)"""
|
||||
article = "I need input_ids to generate"
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=10)
|
||||
input_ids = tokenizer(article, return_tensors="pt").input_ids
|
||||
with self.assertRaises(ValueError):
|
||||
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
|
||||
|
||||
def test_generate_input_features_as_encoder_kwarg(self):
|
||||
"""Test that non-`input_ids` main model inputs are correctly handled as positional arguments"""
|
||||
input_features = floats_tensor((3, 80, 60))
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-WhisperForConditionalGeneration"
|
||||
)
|
||||
input_features.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
|
||||
output_sequences_kwargs = model.generate(input_features=input_features, max_length=5)
|
||||
output_sequences = model.generate(input_features, max_length=5)
|
||||
output_sequences_kwargs = output_sequences_kwargs.cpu().numpy()
|
||||
output_sequences = output_sequences.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences, output_sequences_kwargs))
|
||||
self.assertEqual(output_sequences.shape, (3, 5))
|
||||
|
||||
def test_generate_encoder_outputs_attention_mask(self):
|
||||
"""Test that `generate` can handle attention masks when the encoder outputs are passed"""
|
||||
input_features = floats_tensor((3, 80, 60))
|
||||
attention_mask = torch.randint(0, 2, input_features.shape).to(torch_device)
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-WhisperForConditionalGeneration"
|
||||
)
|
||||
input_features = input_features.to(torch_device)
|
||||
attention_mask = attention_mask.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
|
||||
encoder = model.get_encoder()
|
||||
encoder_outputs = encoder(input_features)
|
||||
|
||||
output_sequences_no_mask = model.generate(encoder_outputs=encoder_outputs)
|
||||
output_sequences_with_mask = model.generate(encoder_outputs=encoder_outputs, attention_mask=attention_mask)
|
||||
output_sequences_no_mask = output_sequences_no_mask.cpu().numpy()
|
||||
output_sequences_with_mask = output_sequences_with_mask.cpu().numpy()
|
||||
|
||||
self.assertFalse(np.array_equal(output_sequences_no_mask, output_sequences_with_mask))
|
||||
|
||||
def test_eos_token_id_int_and_list_greedy_search(self):
|
||||
"""Test that `generate` can handle multiple EOS tokens"""
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
"num_beams": 1,
|
||||
}
|
||||
expectation = 13
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors="pt")
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
model = model.to(torch_device)
|
||||
tokens = tokens.to(torch_device)
|
||||
|
||||
eos_token_id = 873
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
eos_token_id = [873, 198]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_generate_vision2text_conditioning(self):
|
||||
"""Test that `decoder_input_ids` can be used to condition the generation in vision-to-text models"""
|
||||
pixel_values = floats_tensor((2, 3, 30, 30))
|
||||
conditioning_input = torch.tensor([[10], [10]]) # this should be the 2nd output token, after the BOS token
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2"
|
||||
)
|
||||
pixel_values = pixel_values.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
conditioning_input = conditioning_input.to(torch_device)
|
||||
|
||||
# we can condition on decoder_input_ids (expected decoder input) and input_ids (which we pipe internally as
|
||||
# decoder_input_ids, if the encoder is not a model with text input)
|
||||
output_sequences_decoder_input_ids = model.generate(
|
||||
pixel_values, max_length=5, decoder_input_ids=conditioning_input
|
||||
)
|
||||
output_sequences_input_ids = model.generate(pixel_values, max_length=5, input_ids=conditioning_input)
|
||||
output_sequences_decoder_input_ids = output_sequences_decoder_input_ids.cpu().numpy()
|
||||
output_sequences_input_ids = output_sequences_input_ids.cpu().numpy()
|
||||
conditioning_input = conditioning_input.cpu().numpy()
|
||||
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
|
||||
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user