[Test refactor 1/5] Per-folder tests reorganization (#15725)

* Per-folder tests reorganization

Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Lysandre Debut
2022-02-23 15:46:28 -05:00
committed by GitHub
parent fecb08c2b8
commit 29c10a41d0
438 changed files with 636 additions and 565 deletions

View File

View File

@@ -0,0 +1,533 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from ..test_modeling_common import floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers.generation_beam_constraints import PhrasalConstraint
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer, ConstrainedBeamSearchScorer
class BeamSearchTester:
def __init__(
self,
parent,
batch_size=3,
sequence_length=10,
vocab_size=99,
pad_token_id=0,
max_length=20,
num_beams=4,
length_penalty=2.0,
do_early_stopping=True,
num_beam_hyps_to_keep=2,
):
self.parent = parent
self.batch_size = batch_size
self.sequence_length = sequence_length
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
# cannot be randomely generated
self.eos_token_id = vocab_size + 1
def prepare_beam_scorer(self, **kwargs):
return BeamSearchScorer(
batch_size=kwargs.get("batch_size", self.batch_size),
num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
)
def prepare_inputs(self):
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
return (input_ids, next_tokens, next_indices, next_scores)
def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
beam_scorer = self.prepare_beam_scorer(do_early_stopping=True)
beam_hyp = beam_scorer._beam_hyps[0]
self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size)
# check correct type
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
# check that num_beams is correctly set
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
# check for early stopping deactivated
for beam_idx in range(self.num_beams):
beam_hyp.add(input_ids[beam_idx], -10.0)
# if early stopping True -> score does not matter
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
# re-init
beam_scorer = self.prepare_beam_scorer(do_early_stopping=False)
beam_hyp = beam_scorer._beam_hyps[0]
# add `num_beams + 1` beams to change `worst_score`
for beam_idx in range(self.num_beams + 1):
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
# -10.0 is removed => -9.0 is worst score
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length**beam_hyp.length_penalty))
# -5.0 is better than worst score => should not be finished
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
# -20.0 is worse than worst score => should be finished
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores):
# check too many eos tokens
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[0, :] = self.eos_token_id
with self.parent.assertRaises(ValueError):
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
# check all batches are done
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
# beam scorer should be done
self.parent.assertTrue(beam_scorer.is_done)
# check
beam_scorer = self.prepare_beam_scorer()
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
def cut_expected_tensor(tensor):
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
# check all outptus
# cut out id of eos token and take best `num_beams` outputs
expected_output_tokens = cut_expected_tensor(tokens)
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
expected_output_indices = (
cut_expected_tensor(next_indices)
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
)
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
)
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
# max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1
beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)
# update beams and append to input_ids
tokens = next_tokens.clone()
# first batch, first output has to finish with eos token id since scores are correctly sorted
tokens[0, 0] = self.eos_token_id
# make sure corresponding score is as good as possible to surely be picked first
next_scores[0, 0] = 0.0
beam_outputs = beam_scorer.process(
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
sequence_output = beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])
# check sequence_scores
self.parent.assertFalse((sequence_scores > 0).any().item())
# first batch has to finish with eos_token
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)
# other batches cannot finish with eos token
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
beam_scorer.num_beam_hyps_to_keep = self.num_beams
sequence_output = beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
class ConstrainedBeamSearchTester:
def __init__(
self,
parent,
constraints=None,
batch_size=3,
sequence_length=10,
vocab_size=99,
pad_token_id=0,
max_length=20,
num_beams=4,
length_penalty=2.0,
do_early_stopping=True,
num_beam_hyps_to_keep=2,
):
self.parent = parent
self.batch_size = batch_size
self.sequence_length = sequence_length
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
self.max_length = max_length
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
if constraints is None:
force_tokens = torch.randint(10, 50, (1, 2)).type(torch.LongTensor)[0]
constraints = [
PhrasalConstraint(force_tokens),
]
self.constraints = constraints
# cannot be randomely generated
self.eos_token_id = vocab_size + 1
def prepare_constrained_beam_scorer(self, **kwargs):
return ConstrainedBeamSearchScorer(
constraints=kwargs.get("constraints", self.constraints),
batch_size=kwargs.get("batch_size", self.batch_size),
num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
)
def prepare_inputs(self):
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
scores_for_all_vocab, _ = (
-floats_tensor((self.batch_size * self.num_beams, self.vocab_size)).to(torch_device)
).sort(descending=True)
return (input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab)
def check_beam_hypotheses(self, input_ids, *args):
# check that correct number of beam hypotheses is set in beam scorer
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=True)
beam_hyp = constrained_beam_scorer._beam_hyps[0]
self.parent.assertEqual(len(constrained_beam_scorer._beam_hyps), self.batch_size)
# check correct type
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
# check that num_beams is correctly set
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
# check for early stopping deactivated
for beam_idx in range(self.num_beams):
beam_hyp.add(input_ids[beam_idx], -10.0)
# if early stopping True -> score does not matter
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
# re-init
constrained_beam_scorer = self.prepare_constrained_beam_scorer(do_early_stopping=False)
beam_hyp = constrained_beam_scorer._beam_hyps[0]
# add `num_beams + 1` beams to change `worst_score`
for beam_idx in range(self.num_beams + 1):
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
# -10.0 is removed => -9.0 is worst score
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length**beam_hyp.length_penalty))
# -5.0 is better than worst score => should not be finished
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
# -20.0 is worse than worst score => should be finished
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
def check_constrained_beam_scorer_update(
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
):
# check too many eos tokens
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence
tokens = next_tokens.clone()
tokens[0, :] = self.eos_token_id
with self.parent.assertRaises(ValueError):
constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
# check all batches are done
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
tokens = next_tokens.clone()
tokens[:, : self.num_beams] = self.eos_token_id
constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
# beam scorer should be done
self.parent.assertTrue(constrained_beam_scorer.is_done)
# check
constrained_beam_scorer = self.prepare_constrained_beam_scorer()
tokens = next_tokens.clone()
tokens[:, 1] = self.eos_token_id
beam_outputs = constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
def cut_expected_tensor(tensor):
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
# check all outptus
# cut out id of eos token and take best `num_beams` outputs
expected_output_tokens = cut_expected_tensor(tokens)
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
expected_output_indices = (
cut_expected_tensor(next_indices)
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
)
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
for batch_idx in range(self.batch_size):
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
self.parent.assertListEqual(
input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
)
def check_constrained_beam_scorer_finalize(
self, input_ids, next_tokens, next_indices, next_scores, scores_for_all_vocab
):
# max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1
# for testing finalize, we do want to have fulfilled constraints
fulfilling_sequence = torch.stack([constraint.token_ids for constraint in self.constraints]).flatten()
fulfill_len = fulfilling_sequence.size(0)
input_ids[:, :fulfill_len] = fulfilling_sequence
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False
)
constraints = constrained_beam_scorer.constraints
# update beams and append to input_ids
tokens = next_tokens.clone()
# first batch, first output has to finish with eos token id since scores are correctly sorted
tokens[0, 0] = self.eos_token_id
# make sure corresponding score is as good as possible to surely be picked first
next_scores[0, 0] = 0.0
beam_outputs = constrained_beam_scorer.process(
input_ids, next_scores, tokens, next_indices, scores_for_all_vocab, eos_token_id=self.eos_token_id
)
output_scores = beam_outputs["next_beam_scores"]
output_tokens = beam_outputs["next_beam_tokens"]
output_indices = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
# finalize
sequence_output = constrained_beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
self.parent.assertListEqual(list(sequences.shape), [self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.batch_size])
# check sequence_scores
self.parent.assertFalse((sequence_scores > 0).any().item())
# first batch has to finish with eos_token
self.parent.assertEqual(sequences[0, -1].item(), self.eos_token_id)
# other batches cannot finish with eos token
self.parent.assertNotEqual(sequences[1, -1].item(), self.eos_token_id)
self.parent.assertNotEqual(sequences[2, -1].item(), self.eos_token_id)
# test that the constraint is indeed fulfilled
for output in sequences:
for constraint in constraints:
forced_token_ids = constraint.token_ids
self.parent.assertEqual(self._check_sequence_inside_sequence(output, forced_token_ids), True)
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
# constrained_beam_scorer.num_beam_hyps_to_keep = self.num_beams
constrained_beam_scorer = self.prepare_constrained_beam_scorer(
num_beam_hyps_to_keep=self.num_beams, length_penalty=1.0, do_early_stopping=False
)
sequence_output = constrained_beam_scorer.finalize(
input_ids,
output_scores,
output_tokens,
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]
self.parent.assertListEqual(list(sequences.shape), [self.num_beams * self.batch_size, max_length])
self.parent.assertListEqual(list(sequence_scores.shape), [self.num_beams * self.batch_size])
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
# set to same device. we don't care what device.
tensor_1, tensor_2 = tensor_1.cpu(), tensor_2.cpu()
in_order = tensor_1.size(0) <= tensor_2.size(0)
longer = tensor_2 if in_order else tensor_1
shorter = tensor_1 if in_order else tensor_2
flag = False
chunk_size = shorter.size(0)
for chunk_idx in range(longer.size(0) - chunk_size + 1):
subseq = longer[chunk_idx : chunk_idx + chunk_size]
if torch.equal(subseq, shorter):
flag = True
break
return flag
@require_torch
class BeamSearchTest(unittest.TestCase):
def setUp(self):
self.beam_search_tester = BeamSearchTester(self)
def test_beam_hypotheses(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_hypotheses(*inputs)
def test_beam_scorer_update(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_scorer_update(*inputs)
def test_beam_scorer_finalize(self):
inputs = self.beam_search_tester.prepare_inputs()
self.beam_search_tester.check_beam_scores_finalize(*inputs)
@require_torch
class ConstrainedBeamSearchTest(unittest.TestCase):
def setUp(self):
self.constrained_beam_search_tester = ConstrainedBeamSearchTester(self)
def test_constrained_beam_hypotheses(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_beam_hypotheses(*inputs)
def test_constrained_beam_scorer_update(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_constrained_beam_scorer_update(*inputs)
def test_constrained_beam_scorer_finalize(self):
inputs = self.constrained_beam_search_tester.prepare_inputs()
self.constrained_beam_search_tester.check_constrained_beam_scorer_finalize(*inputs)

View File

@@ -0,0 +1,301 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import is_flax_available
from transformers.testing_utils import require_flax
from ..test_modeling_flax_common import ids_tensor
if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.generation_flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor,
FlaxLogitsProcessorList,
FlaxMinLengthLogitsProcessor,
FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper,
)
@require_flax
class LogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
scores = np.ones((batch_size, length)) / length
return scores
def test_temperature_dist_warper(self):
input_ids = None
length = 20
scores = self._get_uniform_logits(batch_size=2, length=length)
# tweak scores to not be uniform anymore
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
# compute softmax
probs = jax.nn.softmax(scores, axis=-1)
temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy(), cur_len=None), axis=-1)
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy(), cur_len=None), axis=-1)
# uniform distribution stays uniform
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3))
# sharp peaks get higher, valleys get lower
self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max())
self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min())
# smooth peaks get lower, valleys get higher
self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max())
self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min())
def test_top_k_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create ramp distribution
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
top_k_warp = FlaxTopKLogitsWarper(3)
scores = top_k_warp(input_ids, ramp_logits, cur_len=None)
# check that correct tokens are filtered
self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
self.assertListEqual(jnp.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])
# check special case
length = 5
top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len=None)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
def test_top_p_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
top_p_warp = FlaxTopPLogitsWarper(0.7)
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
# check edge cases with negative and extreme logits
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy() - (
vocab_size // 2
)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len=None)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])
def test_min_length_dist_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
min_dist_processor = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
# check that min length is applied at length 5
input_ids = ids_tensor((batch_size, 20), vocab_size=20)
cur_len = 5
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")])
# check that min length is not applied anymore at length 15
scores = self._get_uniform_logits(batch_size, vocab_size)
cur_len = 15
scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=cur_len)
self.assertFalse(jnp.isinf(scores_before_min_length).any())
def test_forced_bos_token_logits_processor(self):
vocab_size = 20
batch_size = 4
bos_token_id = 0
logits_processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
# check that all scores are -inf except the bos_token_id score
input_ids = ids_tensor((batch_size, 1), vocab_size=20)
cur_len = 1
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len=cur_len)
self.assertTrue(jnp.isneginf(scores[:, bos_token_id + 1 :]).all())
self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0]) # score for bos_token_id shold be zero
# check that bos_token_id is not forced if current length is greater than 1
cur_len = 3
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len=cur_len)
self.assertFalse(jnp.isinf(scores).any())
def test_forced_eos_token_logits_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
max_length = 5
logits_processor = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
# check that all scores are -inf except the eos_token_id when max_length is reached
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
cur_len = 4
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len=cur_len)
self.assertTrue(jnp.isneginf(scores[:, eos_token_id + 1 :]).all())
self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero
# check that eos_token_id is not forced if max_length is not reached
cur_len = 3
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores, cur_len=cur_len)
self.assertFalse(jnp.isinf(scores).any())
def test_processor_list(self):
batch_size = 4
sequence_length = 10
vocab_size = 15
eos_token_id = 2
bos_token_id = 1
max_length = 15
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
input_ids_comp = input_ids.copy()
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_comp = scores.copy()
# instantiate all dist processors
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
top_k_warp = FlaxTopKLogitsWarper(3)
top_p_warp = FlaxTopPLogitsWarper(0.8)
# instantiate all logits processors
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
cur_len = 10
# no processor list
scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
scores = top_k_warp(input_ids, scores, cur_len=cur_len)
scores = top_p_warp(input_ids, scores, cur_len=cur_len)
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
# with processor list
processor = FlaxLogitsProcessorList(
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
)
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
# scores should be equal
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
def test_processor_list_jitted(self):
batch_size = 4
sequence_length = 10
vocab_size = 15
eos_token_id = 2
bos_token_id = 1
max_length = 15
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
input_ids_comp = input_ids.copy()
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_comp = scores.copy()
# instantiate all dist processors
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
top_k_warp = FlaxTopKLogitsWarper(3)
top_p_warp = FlaxTopPLogitsWarper(0.8)
# instantiate all logits processors
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
cur_len = 10
# no processor list
def run_no_processor_list(input_ids, scores, cur_len):
scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
scores = top_k_warp(input_ids, scores, cur_len=cur_len)
scores = top_p_warp(input_ids, scores, cur_len=cur_len)
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
return scores
# with processor list
def run_processor_list(input_ids, scores, cur_len):
processor = FlaxLogitsProcessorList(
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
)
scores = processor(input_ids, scores, cur_len=cur_len)
return scores
jitted_run_no_processor_list = jax.jit(run_no_processor_list)
jitted_run_processor_list = jax.jit(run_processor_list)
scores = jitted_run_no_processor_list(input_ids, scores, cur_len)
scores_comp = jitted_run_processor_list(input_ids, scores_comp, cur_len)
# scores should be equal
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())

View File

@@ -0,0 +1,276 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
if is_flax_available():
import os
import jax
import jax.numpy as jnp
from jax import jit
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
if is_torch_available():
import torch
def ids_tensor(shape, vocab_size, rng=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
output = np.array(values, dtype=jnp.int32).reshape(shape)
return output
def random_attention_mask(shape, rng=None):
attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
# make sure that at least one token is attended to for each batch
attn_mask[:, -1] = 1
return attn_mask
@require_flax
class FlaxGenerationTesterMixin:
model_tester = None
all_generative_model_classes = ()
def _get_input_ids_and_config(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
# cut to half length & take max batch_size 3
max_batch_size = 2
sequence_length = inputs["input_ids"].shape[-1] // 2
input_ids = inputs["input_ids"][:max_batch_size, :sequence_length]
attention_mask = jnp.ones_like(input_ids)
attention_mask = attention_mask[:max_batch_size, :sequence_length]
# generate max 5 tokens
max_length = input_ids.shape[-1] + 5
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
return config, input_ids, attention_mask, max_length
@is_pt_flax_cross_test
def test_greedy_generate_pt_fx(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = False
config.max_length = max_length
config.decoder_start_token_id = 0
for model_class in self.all_generative_model_classes:
flax_model = model_class(config)
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)
flax_generation_outputs = flax_model.generate(input_ids).sequences
pt_generation_outputs = pt_model.generate(torch.tensor(input_ids, dtype=torch.long))
if flax_generation_outputs.shape[-1] > pt_generation_outputs.shape[-1]:
flax_generation_outputs = flax_generation_outputs[:, : pt_generation_outputs.shape[-1]]
self.assertListEqual(pt_generation_outputs.numpy().tolist(), flax_generation_outputs.tolist())
def test_greedy_generate(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = False
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_sample_generate(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = True
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = False
config.max_length = max_length
config.num_beams = 2
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_sample_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.do_sample = True
config.max_length = max_length
config.temperature = 0.8
config.top_k = 10
config.top_p = 0.3
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_greedy_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.max_length = max_length
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate_logits_warper(self):
config, input_ids, _, max_length = self._get_input_ids_and_config()
config.max_length = max_length
config.num_beams = 2
config.min_length = 1
config.forced_bos_token_id = 8
config.forced_eos_token_id = 9
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_greedy_generate_attn_mask(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# pad attention mask on the left
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
config.do_sample = False
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_sample_generate_attn_mask(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# pad attention mask on the left
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
config.do_sample = True
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
def test_beam_search_generate_attn_mask(self):
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# pad attention mask on the left
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
config.num_beams = 2
config.max_length = max_length
for model_class in self.all_generative_model_classes:
model = model_class(config)
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
self.assertEqual(generation_outputs.shape[-1], max_length)
jit_generate = jit(model.generate)
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())

View File

@@ -0,0 +1,506 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from ..test_modeling_common import ids_tensor
if is_torch_available():
import torch
from torch import nn
from transformers.generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoBadWordsLogitsProcessor,
NoRepeatNGramLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
@require_torch
class LogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
return scores
def test_min_lenght_dist_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
# check that min length is applied at length 5
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")])
# check that min length is not applied anymore at length 15
input_ids = ids_tensor((batch_size, 15), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores_before_min_length).any())
def test_temperature_dist_warper(self):
input_ids = None
length = 20
scores = self._get_uniform_logits(batch_size=2, length=length)
# tweak scores to not be uniform anymore
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
# compute softmax
probs = nn.functional.softmax(scores, dim=-1)
temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = nn.functional.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
warped_prob_smooth = nn.functional.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
# uniform distribution stays uniform
self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
self.assertTrue(torch.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3))
# sharp peaks get higher, valleys get lower
self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max())
self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min())
# smooth peaks get lower, valleys get higher
self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max())
self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min())
def test_repetition_penalty_dist_process(self):
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
vocab_size = 10
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
# give values special values
scores[0, 0] = -(1 / vocab_size)
scores[1, 5] = 4 / vocab_size
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
scores = rep_penalty_proc(input_ids, scores.clone())
# check that values were correctly changed
self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) * 2)
self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2)
def test_top_k_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create ramp distribution
ramp_logits = (
torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
)
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
top_k_warp = TopKLogitsWarper(3)
scores = top_k_warp(input_ids, ramp_logits)
# check that correct tokens are filtered
self.assertListEqual(torch.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
self.assertListEqual(torch.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])
# check special cases
length = 5
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
top_k_warp_safety_check = TopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
scores = top_k_warp_safety_check(input_ids, logits)
# uniform dist is not changed
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
ramp_logits = torch.arange(length, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
scores = top_k_warp_safety_check(input_ids, ramp_logits)
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
def test_top_p_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
dist = torch.log(
torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
)
top_p_warp = TopPLogitsWarper(0.7)
filtered_dist = torch.exp(top_p_warp(input_ids, dist))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = torch.tensor(
[[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
)
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
# check edge cases with negative and extreme logits
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
batch_size, 1
) - (vocab_size // 2)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
top_p_warp = TopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = top_p_warp(input_ids, ramp_logits)
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
def test_typical_dist_warper(self):
input_ids = None
vocab_size = 10
batch_size = 2
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
dist = torch.log(
torch.tensor([[0.97, 0.01, 0.01, 0.01], [0.4, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float)
)
typical_warp = TypicalLogitsWarper(0.5)
filtered_dist = torch.exp(typical_warp(input_ids, dist))
# dist should be filtered to keep min num values so that sum is >= 0.7
# exp (-inf) => 0
EXPECTED_FILTERED_DIST = torch.tensor(
[[0.97, 0.0, 0.0, 0.0], [0.0, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float
)
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
# check special cases
length = 5
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
typical_warp_safety_check = TypicalLogitsWarper(mass=0.5, filter_value=0.0, min_tokens_to_keep=3)
scores = typical_warp_safety_check(input_ids, logits)
# uniform dist is not changed
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
# check edge cases with negative and extreme logits
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
batch_size, 1
) - (vocab_size // 2)
# make ramp_logits more extreme
ramp_logits[1] = ramp_logits[1] * 100.0
# make sure at least 2 tokens are kept
typical_warp = TypicalLogitsWarper(0.7, min_tokens_to_keep=2, filter_value=0.0)
filtered_dist = typical_warp(input_ids, ramp_logits)
# first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2
input_ids = torch.tensor([[1, 1, 2, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
scores = self._get_uniform_logits(batch_size, vocab_size)
no_repeat_proc_2_gram = NoRepeatNGramLogitsProcessor(2)
no_repeat_proc_3_gram = NoRepeatNGramLogitsProcessor(3)
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])
# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(
torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]
)
def test_encoder_no_repeat_ngram_dist_processor(self):
vocab_size = 3
num_beams = 2
batch_size = 1
encoder_input_ids = torch.tensor([1, 2, 1, 1], device=torch_device, dtype=torch.long)
input_ids = torch.tensor([[1, 2, 1], [8, 0, 2]], device=torch_device, dtype=torch.long)
scores = self._get_uniform_logits(batch_size * num_beams, vocab_size)
no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids)
no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
# 2-gram would forbid 1st and 2nd token at 1st beam and 1st token (0) at 2nd beam
self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [False, True, False]])
# 3-gram would forbid 1st token at 1st beam and no token at 2nd beam
self.assertListEqual(
torch.isinf(filtered_scores_3_gram).tolist(), [[False, True, False], [False, False, False]]
)
# Batched input
vocab_size = 3
num_beams = 2
batch_size = 2
encoder_input_ids = torch.tensor([[1, 2, 1, 1], [0, 0, 2, 1]], device=torch_device, dtype=torch.long)
input_ids = torch.tensor([[1, 2, 1], [1, 0, 2], [0, 0, 0], [0, 2, 2]], device=torch_device, dtype=torch.long)
scores = self._get_uniform_logits(batch_size * num_beams, vocab_size)
no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids)
no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids)
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
# 2gram
# Batch 1
# - Beam 1: tokens (1, 2) forbidden
# - Beam 2: tokens (1) forbidden
# Batch 2
# - Beam 1: tokens (0, 2) forbidden
# - Beam 2: tokens (1) forbidden
self.assertListEqual(
torch.isinf(filtered_scores_2_gram).tolist(),
[[False, True, True], [False, True, False], [True, False, True], [False, True, False]],
)
# Batch 1
# - Beam 1: tokens (1) forbidden
# - Beam 2: tokens () forbidden
# Batch 2
# - Beam 1: tokens (2) forbidden
# - Beam 2: tokens () forbidden
self.assertListEqual(
torch.isinf(filtered_scores_3_gram).tolist(),
[[False, True, False], [False, False, False], [False, False, True], [False, False, False]],
)
def test_no_bad_words_dist_processor(self):
vocab_size = 5
batch_size = 2
eos_token_id = 4
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
scores = self._get_uniform_logits(batch_size, vocab_size)
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
# batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden
# batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden
# Note that 5th element cannot be forbidden as it is EOS token
self.assertListEqual(
torch.isinf(filtered_scores).tolist(), [[True, True, False, True, False], [True, True, True, False, False]]
)
# check edge case
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[4]], eos_token_id=eos_token_id)
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))
def test_processor_list(self):
batch_size = 4
sequence_length = 10
vocab_size = 15
eos_token_id = 0
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
input_ids_comp = input_ids.clone()
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_comp = scores.clone()
# instantiate all dist processors
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
top_k_warp = TopKLogitsWarper(3)
top_p_warp = TopPLogitsWarper(0.8)
no_repeat_proc = NoRepeatNGramLogitsProcessor(2)
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
# no processor list
scores = min_dist_proc(input_ids, scores)
scores = temp_dist_warp(input_ids, scores)
scores = rep_penalty_proc(input_ids, scores)
scores = top_k_warp(input_ids, scores)
scores = top_p_warp(input_ids, scores)
scores = no_repeat_proc(input_ids, scores)
scores = no_bad_words_dist_proc(input_ids, scores)
# with processor list
processor = LogitsProcessorList(
[
min_dist_proc,
temp_dist_warp,
rep_penalty_proc,
top_k_warp,
top_p_warp,
no_repeat_proc,
no_bad_words_dist_proc,
]
)
scores_comp = processor(input_ids, scores_comp)
# scores should be equal
self.assertTrue(torch.allclose(scores, scores_comp, atol=1e-3))
# input_ids should never be changed
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
def test_prefix_constrained_logits_processor(self):
vocab_size = 5
batch_size = 2
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
scores = self._get_uniform_logits(batch_size, vocab_size)
def prefix_allowed_tokens_fn(batch_id, inputs_ids):
return [[0, 1], [2, 3]][batch_id]
prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, 1)
filtered_scores = prefix_constrained_logits_proc(input_ids, scores.clone())
# batch 1: 1st, 2nd (0, 1) token are allowed
# batch 2: 3rd, 4th (2, 3) token are allowed
self.assertListEqual(
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
)
def test_hamming_diversity(self):
vocab_size = 4
num_beams = 2
num_beam_groups = 2
scores = self._get_uniform_logits(num_beams, vocab_size)
# batch_idx = 0 -> index batch_idx * num_beam_groups -> idx = 0 * 2 = 0 -> penalises tokens 1
# batch_idx = 1 -> index batch_idx * num_beam_groups -> idx = 1 * 2 = 2 -> penalises tokens 1
current_tokens = torch.tensor([0, 3, 1, 2], device=torch_device, dtype=torch.long)
diversity_logits_processor = HammingDiversityLogitsProcessor(
diversity_penalty=1.0, num_beams=num_beams, num_beam_groups=num_beam_groups
)
processed_scores = diversity_logits_processor(None, scores, current_tokens, 1)
self.assertTrue(
torch.allclose(
processed_scores[0], torch.tensor([-0.7500, 0.2500, 0.2500, 0.2500], device=torch_device), atol=1e-3
)
)
self.assertTrue(
torch.allclose(
processed_scores[1], torch.tensor([0.2500, -0.7500, 0.2500, 0.2500], device=torch_device), atol=1e-3
)
)
def test_forced_bos_token_logits_processor(self):
vocab_size = 20
batch_size = 4
bos_token_id = 0
logits_processor = ForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
# check that all scores are -inf except the bos_token_id score
input_ids = ids_tensor((batch_size, 1), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertTrue(torch.isneginf(scores[:, bos_token_id + 1 :]).all())
self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0]) # score for bos_token_id shold be zero
# check that bos_token_id is not forced if current length is greater than 1
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores).any())
def test_forced_eos_token_logits_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
max_length = 5
logits_processor = ForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
# check that all scores are -inf except the eos_token_id when max_length is reached
input_ids = ids_tensor((batch_size, 4), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertTrue(torch.isneginf(scores[:, eos_token_id + 1 :]).all())
self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero
# check that eos_token_id is not forced if max_length is not reached
input_ids = ids_tensor((batch_size, 3), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores = logits_processor(input_ids, scores)
self.assertFalse(torch.isinf(scores).any())
def test_remove_nan_inf_logits_processor(self):
scores = torch.tensor(
[[0.0, 0.7, 0.8, float("nan")], [0.1, float("inf"), 0.3, float("-inf")]], device=torch_device
)
input_ids = ids_tensor((2, 4), vocab_size=20)
logits_processor = InfNanRemoveLogitsProcessor()
scores = logits_processor(input_ids, scores)
self.assertTrue(
torch.allclose(
scores,
torch.tensor(
[[0.0, 0.7, 0.8, 0.0], [0.1, torch.finfo(scores.dtype).max, 0.3, float("-inf")]],
device=torch_device,
),
atol=1e-6,
)
)

View File

@@ -0,0 +1,109 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from ..test_modeling_common import ids_tensor
if is_torch_available():
import torch
from transformers.generation_stopping_criteria import (
MaxLengthCriteria,
MaxNewTokensCriteria,
MaxTimeCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
@require_torch
class StoppingCriteriaTestCase(unittest.TestCase):
def _get_tensors(self, length):
batch_size = 3
vocab_size = 250
input_ids = ids_tensor((batch_size, length), vocab_size)
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
return input_ids, scores
def test_list_criteria(self):
input_ids, scores = self._get_tensors(5)
criteria = StoppingCriteriaList(
[
MaxLengthCriteria(max_length=10),
MaxTimeCriteria(max_time=0.1),
]
)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
def test_max_length_criteria(self):
criteria = MaxLengthCriteria(max_length=10)
input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
def test_max_new_tokens_criteria(self):
criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)
input_ids, scores = self._get_tensors(5)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(9)
self.assertFalse(criteria(input_ids, scores))
input_ids, scores = self._get_tensors(10)
self.assertTrue(criteria(input_ids, scores))
criteria_list = StoppingCriteriaList([criteria])
self.assertEqual(criteria_list.max_length, 10)
def test_max_time_criteria(self):
input_ids, scores = self._get_tensors(5)
criteria = MaxTimeCriteria(max_time=0.1)
self.assertFalse(criteria(input_ids, scores))
criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
self.assertTrue(criteria(input_ids, scores))
def test_validate_stopping_criteria(self):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
with self.assertWarns(UserWarning):
validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
self.assertEqual(len(stopping_criteria), 1)

View File

@@ -0,0 +1,172 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a clone of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_tf
if is_tf_available():
import tensorflow as tf
from transformers.generation_tf_logits_process import (
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
)
from transformers.tf_utils import set_tensor_by_indices_to_value
from ..test_modeling_tf_common import ids_tensor
@require_tf
class TFLogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
scores = tf.ones((batch_size, length), dtype=tf.float32) / length
return scores
def test_min_length_dist_processor(self):
vocab_size = 20
batch_size = 4
eos_token_id = 0
min_dist_processor = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
# check that min length is applied at length 5
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertListEqual(scores_before_min_length[:, eos_token_id].numpy().tolist(), 4 * [-float("inf")])
# check that min length is not applied anymore at length 15
input_ids = ids_tensor((batch_size, 15), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_before_min_length = min_dist_processor(input_ids, scores)
self.assertFalse(tf.math.reduce_any(tf.math.is_inf(scores_before_min_length)).numpy())
def test_repetition_penalty_dist_process(self):
input_ids = tf.constant([[0, 1], [5, 0]], dtype=tf.int32)
vocab_size = 10
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
mask = tf.cast(tf.constant([[1] + 9 * [0], 10 * [0]]), tf.bool)
scores = set_tensor_by_indices_to_value(scores, mask, -1 / vocab_size)
mask = tf.cast(tf.constant([10 * [0], 5 * [0] + [1] + 4 * [0]]), tf.bool)
scores = set_tensor_by_indices_to_value(scores, mask, 4 / vocab_size)
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
scores = rep_penalty_proc(input_ids, tf.identity(scores))
# check that values were correctly changed
self.assertAlmostEqual(scores[0, 0].numpy(), -(1 / vocab_size) * 2)
self.assertAlmostEqual(scores[0, 1].numpy(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 0].numpy(), (1 / vocab_size) / 2)
self.assertAlmostEqual(scores[1, 5].numpy(), (4 / vocab_size) / 2)
def test_no_repeat_ngram_dist_processor(self):
vocab_size = 3
batch_size = 2
input_ids = tf.constant([[1, 1, 2, 1], [0, 1, 0, 1]], dtype=tf.int32)
scores = self._get_uniform_logits(batch_size, vocab_size)
no_repeat_proc_2_gram = TFNoRepeatNGramLogitsProcessor(2)
no_repeat_proc_3_gram = TFNoRepeatNGramLogitsProcessor(3)
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, tf.identity(scores))
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, tf.identity(scores))
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(
tf.math.is_inf(filtered_scores_2_gram).numpy().tolist(), [[False, True, True], [True, False, False]]
)
# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(
tf.math.is_inf(filtered_scores_3_gram).numpy().tolist(), [[False, False, False], [True, False, False]]
)
def test_no_bad_words_dist_processor(self):
vocab_size = 5
batch_size = 2
eos_token_id = 4
input_ids = tf.constant([[0, 1, 3, 1], [0, 1, 0, 1]], dtype=tf.int32)
bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
scores = self._get_uniform_logits(batch_size, vocab_size)
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
filtered_scores = no_bad_words_dist_proc(input_ids, tf.identity(scores))
# batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden
# batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden
self.assertListEqual(
tf.math.is_inf(filtered_scores).numpy().tolist(),
[[True, True, False, True, True], [True, True, True, False, True]],
)
def test_processor_list(self):
batch_size = 4
sequence_length = 10
vocab_size = 15
eos_token_id = 0
# dummy input_ids and scores
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
input_ids_comp = tf.identity(input_ids)
scores = self._get_uniform_logits(batch_size, vocab_size)
scores_comp = tf.identity(scores)
# instantiate all dist processors
min_dist_proc = TFMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
rep_penalty_proc = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)
no_repeat_proc = TFNoRepeatNGramLogitsProcessor(2)
no_bad_words_dist_proc = TFNoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
# no processor list
scores = min_dist_proc(input_ids, scores)
scores = rep_penalty_proc(input_ids, scores)
scores = no_repeat_proc(input_ids, scores)
scores = no_bad_words_dist_proc(input_ids, scores)
# with processor list
processor = TFLogitsProcessorList(
[
min_dist_proc,
rep_penalty_proc,
no_repeat_proc,
no_bad_words_dist_proc,
]
)
scores_comp = processor(input_ids, scores_comp)
# remove inf
scores = set_tensor_by_indices_to_value(scores, tf.math.is_inf(scores), -1e9)
scores_comp = set_tensor_by_indices_to_value(scores_comp, tf.math.is_inf(scores_comp), -1e9)
# scores should be equal
tf.debugging.assert_near(scores, scores_comp, atol=1e-3)
# input_ids should never be changed
self.assertListEqual(input_ids.numpy().tolist(), input_ids_comp.numpy().tolist())

File diff suppressed because it is too large Load Diff