Refactoring the generate() function (#6949)
* first draft * show design proposition for new generate method * up * make better readable * make first version * gpt2 tests pass * make beam search for gpt2 work * add first encoder-decoder code * delete typo * make t5 work * save indermediate * make bart work with beam search * finish beam search bart / t5 * add default kwargs * make more tests pass * fix no bad words sampler * some fixes and tests for all distribution processors * fix test * fix rag slow tests * merge to master * add nograd to generate * make all slow tests pass * speed up generate * fix edge case bug * small fix * correct typo * add type hints and docstrings * fix typos in tests * add beam search tests * add tests for beam scorer * fix test rag * finish beam search tests * move generation tests in seperate file * fix generation tests * more tests * add aggressive generation tests * fix tests * add gpt2 sample test * add more docstring * add more docs * finish doc strings * apply some more of sylvains and sams comments * fix some typos * make fix copies * apply lysandres and sylvains comments * final corrections on examples * small fix for reformer
This commit is contained in:
committed by
GitHub
parent
b63beb743c
commit
a1bbcf3f6c
239
tests/test_generation_beam_search.py
Normal file
239
tests/test_generation_beam_search.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a clone of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
from .test_modeling_common import floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.generation_beam_search import BeamHypotheses, BeamSearchScorer
|
||||
|
||||
|
||||
class BeamSearchTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=3,
|
||||
sequence_length=10,
|
||||
vocab_size=99,
|
||||
pad_token_id=0,
|
||||
max_length=20,
|
||||
num_beams=4,
|
||||
length_penalty=2.0,
|
||||
do_early_stopping=True,
|
||||
num_beam_hyps_to_keep=2,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
self.vocab_size = vocab_size
|
||||
self.pad_token_id = pad_token_id
|
||||
self.max_length = max_length
|
||||
self.num_beams = num_beams
|
||||
self.length_penalty = length_penalty
|
||||
self.do_early_stopping = do_early_stopping
|
||||
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
||||
|
||||
# cannot be randomely generated
|
||||
self.eos_token_id = vocab_size + 1
|
||||
|
||||
def prepare_beam_scorer(self, **kwargs):
|
||||
return BeamSearchScorer(
|
||||
batch_size=kwargs.get("batch_size", self.batch_size),
|
||||
max_length=kwargs.get("max_length", self.max_length),
|
||||
num_beams=kwargs.get("num_beams", self.num_beams),
|
||||
device=torch_device,
|
||||
length_penalty=kwargs.get("length_penalty", self.length_penalty),
|
||||
do_early_stopping=kwargs.get("do_early_stopping", self.do_early_stopping),
|
||||
num_beam_hyps_to_keep=kwargs.get("num_beam_hyps_to_keep", self.num_beam_hyps_to_keep),
|
||||
)
|
||||
|
||||
def prepare_inputs(self):
|
||||
input_ids = ids_tensor((self.batch_size * self.num_beams, self.sequence_length), self.vocab_size)
|
||||
next_tokens = ids_tensor((self.batch_size, 2 * self.num_beams), self.vocab_size).to(torch_device)
|
||||
next_indices = ids_tensor((self.batch_size, 2 * self.num_beams), self.num_beams).to(torch_device)
|
||||
next_scores, _ = (-floats_tensor((self.batch_size, 2 * self.num_beams)).to(torch_device)).sort(descending=True)
|
||||
return (input_ids, next_tokens, next_indices, next_scores)
|
||||
|
||||
def check_beam_hypotheses(self, input_ids, *args):
|
||||
# check that correct number of beam hypotheses is set in beam scorer
|
||||
beam_scorer = self.prepare_beam_scorer(do_early_stopping=True)
|
||||
beam_hyp = beam_scorer._beam_hyps[0]
|
||||
|
||||
self.parent.assertEqual(len(beam_scorer._beam_hyps), self.batch_size)
|
||||
|
||||
# check correct type
|
||||
self.parent.assertTrue(isinstance(beam_hyp, BeamHypotheses))
|
||||
|
||||
# check that num_beams is correctly set
|
||||
self.parent.assertEqual(beam_hyp.num_beams, self.num_beams)
|
||||
|
||||
# check for early stopping deactivated
|
||||
for beam_idx in range(self.num_beams):
|
||||
beam_hyp.add(input_ids[beam_idx], -10.0)
|
||||
|
||||
# if early stopping True -> score does not matter
|
||||
self.parent.assertTrue(beam_hyp.is_done(-10.0, 5))
|
||||
|
||||
# re-init
|
||||
beam_scorer = self.prepare_beam_scorer(do_early_stopping=False)
|
||||
beam_hyp = beam_scorer._beam_hyps[0]
|
||||
|
||||
# add `num_beams + 1` beams to change `worst_score`
|
||||
for beam_idx in range(self.num_beams + 1):
|
||||
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
|
||||
|
||||
# -10.0 is removed => -9.0 is worst score
|
||||
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))
|
||||
|
||||
# -5.0 is better than worst score => should not be finished
|
||||
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
|
||||
|
||||
# -20.0 is worse than worst score => should be finished
|
||||
self.parent.assertTrue(beam_hyp.is_done(-20.0, self.sequence_length))
|
||||
|
||||
def check_beam_scorer_update(self, input_ids, next_tokens, next_indices, next_scores):
|
||||
# check too many eos tokens
|
||||
beam_scorer = self.prepare_beam_scorer()
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[0, :] = self.eos_token_id
|
||||
|
||||
with self.parent.assertRaises(ValueError):
|
||||
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
|
||||
|
||||
# check all batches are done
|
||||
beam_scorer = self.prepare_beam_scorer()
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, : self.num_beams] = self.eos_token_id
|
||||
beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
|
||||
# beam scorer should be done
|
||||
self.parent.assertTrue(beam_scorer.is_done)
|
||||
|
||||
# check
|
||||
beam_scorer = self.prepare_beam_scorer()
|
||||
|
||||
tokens = next_tokens.clone()
|
||||
tokens[:, 1] = self.eos_token_id
|
||||
beam_outputs = beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
|
||||
)
|
||||
output_scores = beam_outputs["next_beam_scores"]
|
||||
output_tokens = beam_outputs["next_beam_tokens"]
|
||||
output_indices = beam_outputs["next_beam_indices"]
|
||||
|
||||
def cut_expected_tensor(tensor):
|
||||
return torch.cat([tensor[:, :1], tensor[:, 2 : self.num_beams + 1]], dim=1).flatten()
|
||||
|
||||
# check all outptus
|
||||
# cut out id of eos token and take best `num_beams` outputs
|
||||
expected_output_tokens = cut_expected_tensor(tokens)
|
||||
expected_output_scores = cut_expected_tensor(next_scores)
|
||||
|
||||
# add num_beams * batch_idx
|
||||
expected_output_indices = (
|
||||
cut_expected_tensor(next_indices)
|
||||
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
|
||||
)
|
||||
|
||||
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
|
||||
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
|
||||
self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))
|
||||
|
||||
# make sure ids of eos token are correctly saved in beam_hyps of beam scorer
|
||||
for batch_idx in range(self.batch_size):
|
||||
correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
|
||||
self.parent.assertListEqual(
|
||||
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
|
||||
)
|
||||
|
||||
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
||||
# max_length should be only one more than current input_ids to check that eos is correctly appended
|
||||
max_length = self.sequence_length + 1
|
||||
beam_scorer = self.prepare_beam_scorer(
|
||||
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
|
||||
)
|
||||
|
||||
# update beams and append to input_ids
|
||||
tokens = next_tokens.clone()
|
||||
# first batch, first output has to finish with eos token id since scores are correctly sorted
|
||||
tokens[0, 0] = self.eos_token_id
|
||||
# make sure corresponding score is as good as possible to surely be picked first
|
||||
next_scores[0, 0] = 0.0
|
||||
beam_outputs = beam_scorer.process(
|
||||
input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
|
||||
)
|
||||
output_scores = beam_outputs["next_beam_scores"]
|
||||
output_tokens = beam_outputs["next_beam_tokens"]
|
||||
output_indices = beam_outputs["next_beam_indices"]
|
||||
|
||||
input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)
|
||||
|
||||
# finalize
|
||||
decoded = beam_scorer.finalize(
|
||||
input_ids,
|
||||
output_scores,
|
||||
output_tokens,
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
)
|
||||
# since `num_beam_hyps_to_keep` = 1 => only return `batch_size` x `max_length`
|
||||
self.parent.assertListEqual(list(decoded.shape), [self.batch_size, max_length])
|
||||
|
||||
# first batch has to finish with eos_token
|
||||
self.parent.assertEqual(decoded[0, -1].item(), self.eos_token_id)
|
||||
|
||||
# other batches cannot finish with eos token
|
||||
self.parent.assertNotEqual(decoded[1, -1].item(), self.eos_token_id)
|
||||
self.parent.assertNotEqual(decoded[2, -1].item(), self.eos_token_id)
|
||||
|
||||
# now test that if `num_beam_hyps_to_keep` is 3 => all beams are returned
|
||||
beam_scorer.num_beam_hyps_to_keep = self.num_beams
|
||||
decoded = beam_scorer.finalize(
|
||||
input_ids,
|
||||
output_scores,
|
||||
output_tokens,
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
)
|
||||
self.parent.assertListEqual(list(decoded.shape), [self.num_beams * self.batch_size, max_length])
|
||||
|
||||
|
||||
@require_torch
|
||||
class BeamSearchTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.beam_search_tester = BeamSearchTester(self)
|
||||
|
||||
def test_beam_hypotheses(self):
|
||||
inputs = self.beam_search_tester.prepare_inputs()
|
||||
self.beam_search_tester.check_beam_hypotheses(*inputs)
|
||||
|
||||
def test_beam_scorer_update(self):
|
||||
inputs = self.beam_search_tester.prepare_inputs()
|
||||
self.beam_search_tester.check_beam_scorer_update(*inputs)
|
||||
|
||||
def test_beam_scorer_finalize(self):
|
||||
inputs = self.beam_search_tester.prepare_inputs()
|
||||
self.beam_search_tester.check_beam_scores_finalize(*inputs)
|
||||
283
tests/test_generation_logits_process.py
Normal file
283
tests/test_generation_logits_process.py
Normal file
@@ -0,0 +1,283 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a clone of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
from .test_modeling_common import ids_tensor
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers.generation_logits_process import (
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class LogitsProcessorTest(unittest.TestCase):
|
||||
def _get_uniform_logits(self, batch_size: int, length: int):
|
||||
scores = torch.ones((batch_size, length), device=torch_device, dtype=torch.float) / length
|
||||
return scores
|
||||
|
||||
def test_min_lenght_dist_processor(self):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
eos_token_id = 0
|
||||
|
||||
min_dist_processor = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
|
||||
# check that min length is applied at length 5
|
||||
input_ids = ids_tensor((batch_size, 5), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_before_min_length = min_dist_processor(input_ids, scores)
|
||||
self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")])
|
||||
|
||||
# check that min length is not applied anymore at length 15
|
||||
input_ids = ids_tensor((batch_size, 15), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_before_min_length = min_dist_processor(input_ids, scores)
|
||||
self.assertFalse(torch.isinf(scores_before_min_length).any())
|
||||
|
||||
def test_temperature_dist_warper(self):
|
||||
input_ids = None
|
||||
length = 20
|
||||
|
||||
scores = self._get_uniform_logits(batch_size=2, length=length)
|
||||
|
||||
# tweak scores to not be uniform anymore
|
||||
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
|
||||
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
|
||||
|
||||
# compute softmax
|
||||
probs = F.softmax(scores, dim=-1)
|
||||
|
||||
temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5)
|
||||
temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3)
|
||||
|
||||
warped_prob_sharp = F.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
|
||||
warped_prob_smooth = F.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
|
||||
|
||||
# uniform distribution stays uniform
|
||||
self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
|
||||
self.assertTrue(torch.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3))
|
||||
|
||||
# sharp peaks get higher, valleys get lower
|
||||
self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max())
|
||||
self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min())
|
||||
|
||||
# smooth peaks get lower, valleys get higher
|
||||
self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max())
|
||||
self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min())
|
||||
|
||||
def test_repetition_penalty_dist_process(self):
|
||||
input_ids = torch.tensor([[0, 1], [5, 0]], device=torch_device, dtype=torch.long)
|
||||
vocab_size = 10
|
||||
|
||||
scores = self._get_uniform_logits(batch_size=2, length=vocab_size)
|
||||
|
||||
# give values special values
|
||||
scores[0, 0] = -(1 / vocab_size)
|
||||
scores[1, 5] = 4 / vocab_size
|
||||
|
||||
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
|
||||
scores = rep_penalty_proc(input_ids, scores.clone())
|
||||
|
||||
# check that values were correctly changed
|
||||
self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) * 2)
|
||||
self.assertAlmostEqual(scores[0, 1].item(), (1 / vocab_size) / 2)
|
||||
|
||||
self.assertAlmostEqual(scores[1, 0].item(), (1 / vocab_size) / 2)
|
||||
self.assertAlmostEqual(scores[1, 5].item(), (4 / vocab_size) / 2)
|
||||
|
||||
def test_top_k_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create ramp distribution
|
||||
ramp_logits = (
|
||||
torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
|
||||
)
|
||||
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
|
||||
|
||||
top_k_warp = TopKLogitsWarper(3)
|
||||
|
||||
scores = top_k_warp(input_ids, ramp_logits)
|
||||
|
||||
# check that correct tokens are filtered
|
||||
self.assertListEqual(torch.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
|
||||
self.assertListEqual(torch.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])
|
||||
|
||||
# check special cases
|
||||
length = 5
|
||||
|
||||
logits = self._get_uniform_logits(batch_size=batch_size, length=length)
|
||||
top_k_warp_safety_check = TopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
|
||||
|
||||
scores = top_k_warp_safety_check(input_ids, logits)
|
||||
# uniform dist is not changed
|
||||
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])
|
||||
|
||||
ramp_logits = torch.arange(length, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(batch_size, 1)
|
||||
scores = top_k_warp_safety_check(input_ids, ramp_logits)
|
||||
|
||||
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
||||
self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
|
||||
|
||||
def test_top_p_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||
dist = torch.log(
|
||||
torch.tensor([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float)
|
||||
)
|
||||
|
||||
top_p_warp = TopPLogitsWarper(0.7)
|
||||
filtered_dist = torch.exp(top_p_warp(input_ids, dist))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||
# exp (-inf) => 0
|
||||
EXPECTED_FILTERED_DIST = torch.tensor(
|
||||
[[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]], device=torch_device, dtype=torch.float
|
||||
)
|
||||
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||
|
||||
# check edge cases with negative and extreme logits
|
||||
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
|
||||
batch_size, 1
|
||||
) - (vocab_size // 2)
|
||||
|
||||
# make ramp_logits more extreme
|
||||
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||
|
||||
# make sure at least 2 tokens are kept
|
||||
top_p_warp = TopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
||||
filtered_dist = top_p_warp(input_ids, ramp_logits)
|
||||
|
||||
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
|
||||
|
||||
def test_no_repeat_ngram_dist_processor(self):
|
||||
vocab_size = 3
|
||||
batch_size = 2
|
||||
|
||||
input_ids = torch.tensor([[1, 1, 2, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
no_repeat_proc_2_gram = NoRepeatNGramLogitsProcessor(2)
|
||||
no_repeat_proc_3_gram = NoRepeatNGramLogitsProcessor(3)
|
||||
|
||||
filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone())
|
||||
filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone())
|
||||
|
||||
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
|
||||
self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])
|
||||
|
||||
# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
|
||||
self.assertListEqual(
|
||||
torch.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]]
|
||||
)
|
||||
|
||||
def test_no_bad_words_dist_processor(self):
|
||||
vocab_size = 5
|
||||
batch_size = 2
|
||||
eos_token_id = 4
|
||||
|
||||
input_ids = torch.tensor([[0, 1, 3, 1], [0, 1, 0, 1]], device=torch_device, dtype=torch.long)
|
||||
bad_word_tokens = [[1], [4], [1, 0], [0, 1, 2], [1, 3, 1, 3]]
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
|
||||
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id)
|
||||
|
||||
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
|
||||
|
||||
# batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden
|
||||
# batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden
|
||||
# Note that 5th element cannot be forbidden as it is EOS token
|
||||
self.assertListEqual(
|
||||
torch.isinf(filtered_scores).tolist(), [[True, True, False, True, False], [True, True, True, False, False]]
|
||||
)
|
||||
|
||||
# check edge case
|
||||
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[4]], eos_token_id=eos_token_id)
|
||||
filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone())
|
||||
self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3))
|
||||
|
||||
def test_processor_list(self):
|
||||
batch_size = 4
|
||||
sequence_length = 10
|
||||
vocab_size = 15
|
||||
eos_token_id = 0
|
||||
|
||||
# dummy input_ids and scores
|
||||
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
|
||||
input_ids_comp = input_ids.clone()
|
||||
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_comp = scores.clone()
|
||||
|
||||
# instantiate all dist processors
|
||||
min_dist_proc = MinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||
temp_dist_warp = TemperatureLogitsWarper(temperature=0.5)
|
||||
rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0)
|
||||
top_k_warp = TopKLogitsWarper(3)
|
||||
top_p_warp = TopPLogitsWarper(0.8)
|
||||
no_repeat_proc = NoRepeatNGramLogitsProcessor(2)
|
||||
no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id)
|
||||
|
||||
# no processor list
|
||||
scores = min_dist_proc(input_ids, scores)
|
||||
scores = temp_dist_warp(input_ids, scores)
|
||||
scores = rep_penalty_proc(input_ids, scores)
|
||||
scores = top_k_warp(input_ids, scores)
|
||||
scores = top_p_warp(input_ids, scores)
|
||||
scores = no_repeat_proc(input_ids, scores)
|
||||
scores = no_bad_words_dist_proc(input_ids, scores)
|
||||
|
||||
# with processor list
|
||||
processor = LogitsProcessorList(
|
||||
[
|
||||
min_dist_proc,
|
||||
temp_dist_warp,
|
||||
rep_penalty_proc,
|
||||
top_k_warp,
|
||||
top_p_warp,
|
||||
no_repeat_proc,
|
||||
no_bad_words_dist_proc,
|
||||
]
|
||||
)
|
||||
scores_comp = processor(input_ids, scores_comp)
|
||||
|
||||
# scores should be equal
|
||||
self.assertTrue(torch.allclose(scores, scores_comp, atol=1e-3))
|
||||
|
||||
# input_ids should never be changed
|
||||
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
|
||||
510
tests/test_generation_utils.py
Normal file
510
tests/test_generation_utils.py
Normal file
@@ -0,0 +1,510 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a clone of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import top_k_top_p_filtering
|
||||
from transformers.generation_beam_search import BeamSearchScorer
|
||||
from transformers.generation_logits_process import (
|
||||
LogitsProcessorList,
|
||||
MinLengthLogitsProcessor,
|
||||
NoBadWordsLogitsProcessor,
|
||||
NoRepeatNGramLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
|
||||
|
||||
class GenerationTesterMixin:
|
||||
model_tester = None
|
||||
all_generative_model_classes = ()
|
||||
|
||||
def _get_input_ids_and_config(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# cut to half length & take max batch_size 3
|
||||
max_batch_size = 2
|
||||
sequence_length = input_ids.shape[-1] // 2
|
||||
input_ids = input_ids[:max_batch_size, :sequence_length]
|
||||
attention_mask = attention_mask[:max_batch_size, :sequence_length]
|
||||
|
||||
# generate max 3 tokens
|
||||
max_length = input_ids.shape[-1] + 3
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
return config, input_ids, attention_mask, max_length
|
||||
|
||||
@staticmethod
|
||||
def _get_logits_processor_and_kwargs(input_length, eos_token_id):
|
||||
process_kwargs = {
|
||||
"min_length": input_length + 1,
|
||||
"bad_words_ids": [[1, 0]],
|
||||
"no_repeat_ngram_size": 2,
|
||||
"repetition_penalty": 1.2,
|
||||
}
|
||||
logits_processor = LogitsProcessorList(
|
||||
(
|
||||
[
|
||||
MinLengthLogitsProcessor(process_kwargs["min_length"], eos_token_id),
|
||||
]
|
||||
if eos_token_id is not None
|
||||
else []
|
||||
)
|
||||
+ [
|
||||
NoBadWordsLogitsProcessor(process_kwargs["bad_words_ids"], eos_token_id),
|
||||
NoRepeatNGramLogitsProcessor(process_kwargs["no_repeat_ngram_size"]),
|
||||
RepetitionPenaltyLogitsProcessor(process_kwargs["repetition_penalty"]),
|
||||
]
|
||||
)
|
||||
return process_kwargs, logits_processor
|
||||
|
||||
@staticmethod
|
||||
def _get_warper_and_kwargs(num_beams):
|
||||
warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
|
||||
logits_warper = LogitsProcessorList(
|
||||
[
|
||||
TopKLogitsWarper(top_k=warp_kwargs["top_k"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
|
||||
TopPLogitsWarper(top_p=warp_kwargs["top_p"], min_tokens_to_keep=(2 if num_beams > 1 else 1)),
|
||||
TemperatureLogitsWarper(warp_kwargs["temperature"]),
|
||||
]
|
||||
)
|
||||
return warp_kwargs, logits_warper
|
||||
|
||||
@staticmethod
|
||||
def _get_beam_scorer_and_kwargs(batch_size, max_length, num_return_sequences=1):
|
||||
beam_kwargs = {
|
||||
"early_stopping": False,
|
||||
"length_penalty": 2.0,
|
||||
"num_beams": 2,
|
||||
"num_return_sequences": num_return_sequences,
|
||||
}
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
device=torch_device,
|
||||
length_penalty=beam_kwargs["length_penalty"],
|
||||
do_early_stopping=beam_kwargs["early_stopping"],
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
)
|
||||
return beam_kwargs, beam_scorer
|
||||
|
||||
@staticmethod
|
||||
def _get_encoder_outputs(model, input_ids, attention_mask, num_interleave=1):
|
||||
encoder = model.get_encoder()
|
||||
encoder_outputs = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||
num_interleave, dim=0
|
||||
)
|
||||
input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
|
||||
attention_mask = None
|
||||
return encoder_outputs, input_ids, attention_mask
|
||||
|
||||
def test_greedy_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id
|
||||
)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# check `generate()` and `greedy_search()` are equal
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
max_length = 4
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output_ids_greedy = model.greedy_search(
|
||||
input_ids,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
logits_processor=logits_processor,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_greedy.tolist())
|
||||
|
||||
def test_sample_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id
|
||||
)
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# check `generate()` and `sample()` are equal
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
**logits_warper_kwargs,
|
||||
**process_kwargs,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
else:
|
||||
attention_mask_clone = attention_mask
|
||||
input_ids_clone = input_ids
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_sample = model.sample(
|
||||
input_ids_clone,
|
||||
attention_mask=attention_mask_clone,
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist())
|
||||
|
||||
# check `generate()` and `sample()` yield equal results for `num_return_sequences`
|
||||
num_return_sequences = 3
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=1,
|
||||
max_length=max_length,
|
||||
num_return_sequences=num_return_sequences,
|
||||
attention_mask=attention_mask,
|
||||
**logits_warper_kwargs,
|
||||
**process_kwargs,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=num_return_sequences
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(num_return_sequences, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(num_return_sequences, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(num_return_sequences, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_sample = model.sample(
|
||||
input_ids_clone,
|
||||
attention_mask=attention_mask_clone,
|
||||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
logits_warper=logits_warper,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_sample.tolist())
|
||||
|
||||
def test_beam_search_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
|
||||
input_ids.shape[-1], config.eos_token_id
|
||||
)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# check `generate()` and `beam_search()` are equal
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
|
||||
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_beam_search = model.beam_search(
|
||||
input_ids_clone,
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist())
|
||||
|
||||
# check `generate()` and `beam_search()` are equal for `num_return_sequences`
|
||||
num_return_sequences = 2
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
|
||||
)
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
)
|
||||
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids_clone, attention_mask_clone = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
input_ids_clone = input_ids_clone.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
else:
|
||||
attention_mask_clone = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
input_ids_clone = input_ids.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||
|
||||
with torch.no_grad():
|
||||
output_ids_beam_search = model.beam_search(
|
||||
input_ids_clone,
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask_clone,
|
||||
logits_processor=logits_processor,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_search.tolist())
|
||||
|
||||
def test_beam_sample_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# check `generate()` and `beam_search()` are equal
|
||||
# change `num_return_sequences = 2` but not for `beam_scorer`
|
||||
num_return_sequences = 2
|
||||
if model.config.is_encoder_decoder:
|
||||
max_length = 4
|
||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
||||
input_ids.shape[0] * num_return_sequences, max_length
|
||||
)
|
||||
beam_kwargs["num_return_sequences"] = num_return_sequences
|
||||
torch.manual_seed(0)
|
||||
output_ids_generate = model.generate(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
do_sample=True,
|
||||
max_length=max_length,
|
||||
**beam_kwargs,
|
||||
**logits_warper_kwargs,
|
||||
)
|
||||
# beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
|
||||
kwargs = {}
|
||||
if model.config.is_encoder_decoder:
|
||||
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
|
||||
model, input_ids, attention_mask, num_interleave=beam_scorer.num_beams * num_return_sequences
|
||||
)
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
else:
|
||||
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
|
||||
|
||||
torch.manual_seed(0)
|
||||
with torch.no_grad():
|
||||
output_ids_beam_sample = model.beam_sample(
|
||||
input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
|
||||
beam_scorer,
|
||||
max_length=max_length,
|
||||
attention_mask=attention_mask,
|
||||
logits_warper=logits_warper,
|
||||
**kwargs,
|
||||
)
|
||||
self.assertListEqual(output_ids_generate.tolist(), output_ids_beam_sample.tolist())
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
config, _, _, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# if no bos token id => cannot generate from None
|
||||
if config.bos_token_id is None:
|
||||
return
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
# tests whether the top_k_top_p function behaves as expected
|
||||
def test_top_k_top_p_filtering(self):
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[
|
||||
8.2220991, # 3rd highest value; idx. 0
|
||||
-0.5620044,
|
||||
5.23229752,
|
||||
4.0386393,
|
||||
-6.8798378,
|
||||
-0.54785802,
|
||||
-3.2012153,
|
||||
2.92777176,
|
||||
1.88171953,
|
||||
7.35341276,
|
||||
8.43207833, # 2nd highest value; idx. 10
|
||||
-9.85711836,
|
||||
-5.96209236,
|
||||
-1.13039161,
|
||||
-7.1115294,
|
||||
-0.8369633,
|
||||
-5.3186408,
|
||||
7.06427407,
|
||||
0.81369344,
|
||||
-0.82023817,
|
||||
-5.9179796,
|
||||
0.58813443,
|
||||
-6.99778438,
|
||||
4.71551189,
|
||||
-0.18771637,
|
||||
7.44020759, # 4th highest value; idx. 25
|
||||
9.38450987, # 1st highest value; idx. 26
|
||||
2.12662941,
|
||||
-9.32562038,
|
||||
2.35652522,
|
||||
], # cummulative prob of 4 highest values <= 0.6
|
||||
[
|
||||
0.58425518,
|
||||
4.53139238,
|
||||
-5.57510464,
|
||||
-6.28030699,
|
||||
-7.19529503,
|
||||
-4.02122551,
|
||||
1.39337037,
|
||||
-6.06707057,
|
||||
1.59480517,
|
||||
-9.643119,
|
||||
0.03907799,
|
||||
0.67231762,
|
||||
-8.88206726,
|
||||
6.27115922, # 4th highest value; idx. 13
|
||||
2.28520723,
|
||||
4.82767506,
|
||||
4.30421368,
|
||||
8.8275313, # 2nd highest value; idx. 17
|
||||
5.44029958,
|
||||
-4.4735794,
|
||||
7.38579536, # 3rd highest value; idx. 20
|
||||
-2.91051663,
|
||||
2.61946077,
|
||||
-2.5674762,
|
||||
-9.48959302,
|
||||
-4.02922645,
|
||||
-1.35416918,
|
||||
9.67702323, # 1st highest value; idx. 27
|
||||
-5.89478553,
|
||||
1.85370467,
|
||||
], # cummulative prob of 4 highest values <= 0.6
|
||||
],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
non_inf_expected_idx = torch.tensor(
|
||||
[[0, 0], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 20], [1, 27]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
) # expected non filtered idx as noted above
|
||||
|
||||
non_inf_expected_output = torch.tensor(
|
||||
[
|
||||
8.2221,
|
||||
8.4321,
|
||||
7.4402,
|
||||
9.3845,
|
||||
6.2712,
|
||||
8.8275,
|
||||
7.3858,
|
||||
9.6770,
|
||||
], # expected non filtered values as noted above
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
|
||||
non_inf_output = output[output != -float("inf")].to(device=torch_device)
|
||||
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
||||
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
||||
@@ -23,6 +23,7 @@ from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -128,7 +129,7 @@ def prepare_bart_inputs_dict(
|
||||
|
||||
|
||||
@require_torch
|
||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class BARTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
|
||||
if is_torch_available()
|
||||
|
||||
@@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -357,11 +358,12 @@ class BertModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
BertModel,
|
||||
BertLMHeadModel,
|
||||
BertForMaskedLM,
|
||||
BertForMultipleChoice,
|
||||
BertForNextSentencePrediction,
|
||||
@@ -373,6 +375,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BertModelTester(self)
|
||||
|
||||
@@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -183,9 +184,10 @@ class BertGenerationEncoderTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class BertGenerationEncoderTest(ModelTesterMixin, unittest.TestCase):
|
||||
class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (BertGenerationEncoder, BertGenerationDecoder) if is_torch_available() else ()
|
||||
all_generative_model_classes = (BertGenerationDecoder,) if is_torch_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = BertGenerationEncoderTester(self)
|
||||
|
||||
@@ -147,6 +147,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
|
||||
|
||||
src_text = ["Sam"]
|
||||
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
generated_utterances = model.generate(**model_inputs, **FASTER_GEN_KWARGS)
|
||||
tgt_text = 'Sam is a great name. It means "sun" in Gaelic.'
|
||||
|
||||
@@ -156,6 +157,7 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
|
||||
src_text = "Social anxiety\nWow, I am never shy. Do you have anxiety?\nYes. I end up sweating and blushing and feel like i'm going to throw up.\nand why is that?"
|
||||
|
||||
model_inputs = self.tokenizer([src_text], return_tensors="pt").to(torch_device)
|
||||
|
||||
generated_ids = model.generate(**model_inputs, **FASTER_GEN_KWARGS)[0]
|
||||
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
|
||||
|
||||
@@ -187,6 +189,9 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
|
||||
]
|
||||
|
||||
model_inputs = self.tokenizer(src_text, return_tensors="pt").to(torch_device)
|
||||
|
||||
# model does not have "token_type_ids"
|
||||
model_inputs.pop("token_type_ids")
|
||||
assert isinstance(self.tokenizer, BlenderbotSmallTokenizer)
|
||||
generated_ids = self.model.generate(**model_inputs)[0]
|
||||
reply = self.tokenizer.decode(generated_ids, **TOK_DECODE_KW)
|
||||
@@ -198,10 +203,11 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
|
||||
|
||||
def test_90_generation_from_short_input(self):
|
||||
model_inputs = self.tokenizer(["sam"], return_tensors="pt").to(torch_device)
|
||||
generated_utterances = self.model.generate(**model_inputs)
|
||||
# generated_txt = self.tokenizer.decode(generated_utterances[0])
|
||||
|
||||
# assert generated_txt == "__start__ have you ever heard of sam harris? he's an american singer, songwriter, and actor. __end__"
|
||||
# model does not have "token_type_ids"
|
||||
model_inputs.pop("token_type_ids")
|
||||
generated_utterances = self.model.generate(**model_inputs)
|
||||
|
||||
clean_txt = self.tokenizer.decode(generated_utterances[0], **TOK_DECODE_KW)
|
||||
assert clean_txt in (
|
||||
"have you ever been to a sam club? it's a great club in the south.",
|
||||
|
||||
@@ -44,7 +44,6 @@ if is_torch_available():
|
||||
BertModel,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
|
||||
@@ -882,126 +881,6 @@ class ModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
|
||||
# make sure that input_ids is at most of size 15
|
||||
input_ids = input_ids[..., :15]
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined, model needs input_ids
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True))
|
||||
else:
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating multiple sequences when no beam search generation
|
||||
# is not allowed as it would always generate the same sequences
|
||||
model.generate(input_ids, do_sample=False, num_beams=1, num_return_sequences=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [
|
||||
self._generate_random_bad_tokens(1, model.config),
|
||||
self._generate_random_bad_tokens(2, model.config),
|
||||
]
|
||||
output_tokens = model.generate(
|
||||
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
|
||||
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = (inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# make sure that input_ids is at most of size 15
|
||||
input_ids = input_ids[..., :15]
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))
|
||||
else:
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating more sequences than having beams leads is not possible
|
||||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(
|
||||
model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
)
|
||||
# num_return_sequences > 1, greedy
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [
|
||||
self._generate_random_bad_tokens(1, model.config),
|
||||
self._generate_random_bad_tokens(2, model.config),
|
||||
]
|
||||
output_tokens = model.generate(
|
||||
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
|
||||
|
||||
def _generate_random_bad_tokens(self, num_bad_tokens: int, config) -> List[int]:
|
||||
# special tokens cannot be bad tokens
|
||||
special_tokens = [x for x in [config.bos_token_id, config.eos_token_id, config.pad_token_id] if x is not None]
|
||||
# create random bad tokens that are not special tokens
|
||||
bad_tokens = []
|
||||
while len(bad_tokens) < num_bad_tokens:
|
||||
token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).cpu().numpy()[0]
|
||||
if token not in special_tokens:
|
||||
bad_tokens.append(token)
|
||||
return bad_tokens
|
||||
|
||||
def _check_generated_ids(self, output_ids):
|
||||
for token_id in output_ids[0].tolist():
|
||||
self.assertGreaterEqual(token_id, 0)
|
||||
self.assertLess(token_id, self.model_tester.vocab_size)
|
||||
|
||||
def _check_match_tokens(self, generated_ids, bad_words_ids):
|
||||
# for all bad word tokens
|
||||
for bad_word_ids in bad_words_ids:
|
||||
# for all slices in batch
|
||||
for generated_ids_slice in generated_ids:
|
||||
# for all word idx
|
||||
for i in range(len(bad_word_ids), len(generated_ids_slice)):
|
||||
# if tokens match
|
||||
if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
|
||||
return True
|
||||
return False
|
||||
|
||||
@require_torch_multigpu
|
||||
def test_multigpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -1094,110 +973,3 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(model.config, config)
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
# tests whether the top_k_top_p function behaves as expected
|
||||
def test_top_k_top_p_filtering(self):
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[
|
||||
8.2220991, # 3rd highest value; idx. 0
|
||||
-0.5620044,
|
||||
5.23229752,
|
||||
4.0386393,
|
||||
-6.8798378,
|
||||
-0.54785802,
|
||||
-3.2012153,
|
||||
2.92777176,
|
||||
1.88171953,
|
||||
7.35341276, # 5th highest value; idx. 9
|
||||
8.43207833, # 2nd highest value; idx. 10
|
||||
-9.85711836,
|
||||
-5.96209236,
|
||||
-1.13039161,
|
||||
-7.1115294,
|
||||
-0.8369633,
|
||||
-5.3186408,
|
||||
7.06427407,
|
||||
0.81369344,
|
||||
-0.82023817,
|
||||
-5.9179796,
|
||||
0.58813443,
|
||||
-6.99778438,
|
||||
4.71551189,
|
||||
-0.18771637,
|
||||
7.44020759, # 4th highest value; idx. 25
|
||||
9.38450987, # 1st highest value; idx. 26
|
||||
2.12662941,
|
||||
-9.32562038,
|
||||
2.35652522,
|
||||
], # cumulative prob of 5 highest values <= 0.6
|
||||
[
|
||||
0.58425518,
|
||||
4.53139238,
|
||||
-5.57510464,
|
||||
-6.28030699,
|
||||
-7.19529503,
|
||||
-4.02122551,
|
||||
1.39337037,
|
||||
-6.06707057,
|
||||
1.59480517,
|
||||
-9.643119,
|
||||
0.03907799,
|
||||
0.67231762,
|
||||
-8.88206726,
|
||||
6.27115922, # 4th highest value; idx. 13
|
||||
2.28520723,
|
||||
4.82767506,
|
||||
4.30421368,
|
||||
8.8275313, # 2nd highest value; idx. 17
|
||||
5.44029958, # 5th highest value; idx. 18
|
||||
-4.4735794,
|
||||
7.38579536, # 3rd highest value; idx. 20
|
||||
-2.91051663,
|
||||
2.61946077,
|
||||
-2.5674762,
|
||||
-9.48959302,
|
||||
-4.02922645,
|
||||
-1.35416918,
|
||||
9.67702323, # 1st highest value; idx. 27
|
||||
-5.89478553,
|
||||
1.85370467,
|
||||
], # cumulative prob of 5 highest values <= 0.6
|
||||
],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
non_inf_expected_idx = torch.tensor(
|
||||
[[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
) # expected non filtered idx as noted above
|
||||
|
||||
non_inf_expected_output = torch.tensor(
|
||||
[
|
||||
8.2221,
|
||||
7.3534,
|
||||
8.4321,
|
||||
7.4402,
|
||||
9.3845,
|
||||
6.2712,
|
||||
8.8275,
|
||||
5.4403,
|
||||
7.3858,
|
||||
9.6770,
|
||||
], # expected non filtered values as noted above
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
|
||||
non_inf_output = output[output != -float("inf")].to(device=torch_device)
|
||||
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
||||
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
||||
|
||||
@@ -19,6 +19,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -151,7 +152,7 @@ class CTRLModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
|
||||
all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
|
||||
|
||||
@@ -24,6 +24,7 @@ from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -120,7 +121,7 @@ def prepare_fsmt_inputs_dict(
|
||||
|
||||
|
||||
@require_torch
|
||||
class FSMTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FSMTModel, FSMTForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (FSMTForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
|
||||
@@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -377,7 +378,7 @@ class GPT2ModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification)
|
||||
@@ -510,32 +511,17 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
def test_gpt2_sample(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device) # The president
|
||||
expected_output_ids = [
|
||||
464,
|
||||
1893,
|
||||
286,
|
||||
262,
|
||||
1578,
|
||||
1829,
|
||||
11,
|
||||
290,
|
||||
262,
|
||||
1893,
|
||||
286,
|
||||
262,
|
||||
1578,
|
||||
7526,
|
||||
11,
|
||||
423,
|
||||
587,
|
||||
287,
|
||||
262,
|
||||
2635,
|
||||
] # The president of the United States, and the president of the United Kingdom, have been in the White
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
torch.manual_seed(0)
|
||||
input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids.to(torch_device)
|
||||
output_ids = model.generate(input_ids, do_sample=True)
|
||||
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUT_STR = (
|
||||
"Today is a nice day and if you don't know anything about the state of play during your holiday"
|
||||
)
|
||||
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
|
||||
|
||||
@@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -170,7 +171,7 @@ class OpenAIGPTModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTForSequenceClassification)
|
||||
|
||||
@@ -22,6 +22,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
@@ -853,7 +854,7 @@ class ProphetNetStandaloneEncoderModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (ProphetNetModel, ProphetNetForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (ProphetNetForConditionalGeneration,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
@@ -917,7 +918,7 @@ class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@require_torch
|
||||
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (ProphetNetDecoder, ProphetNetForCausalLM) if is_torch_available() else ()
|
||||
all_generative_model_classes = (ProphetNetForCausalLM,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
|
||||
@@ -26,6 +26,7 @@ from transformers.testing_utils import (
|
||||
)
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -196,11 +197,14 @@ class ReformerModelTester:
|
||||
)
|
||||
|
||||
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
|
||||
if not self.is_training:
|
||||
return
|
||||
|
||||
config.is_decoder = False
|
||||
config.lsh_num_chunks_after = 1
|
||||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
model.train()
|
||||
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)["loss"]
|
||||
loss.backward()
|
||||
|
||||
@@ -569,7 +573,7 @@ class ReformerTesterMixin:
|
||||
|
||||
|
||||
@require_torch
|
||||
class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
|
||||
if is_torch_available()
|
||||
@@ -629,7 +633,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest
|
||||
|
||||
|
||||
@require_torch
|
||||
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(ReformerModel, ReformerModelWithLMHead, ReformerForSequenceClassification, ReformerForQuestionAnswering)
|
||||
if is_torch_available()
|
||||
|
||||
@@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -267,7 +268,7 @@ class RobertaModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
@@ -282,6 +283,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = RobertaModelTester(self)
|
||||
|
||||
@@ -23,6 +23,7 @@ from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -466,7 +467,7 @@ class T5ModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
|
||||
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
|
||||
@@ -592,6 +593,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
|
||||
do_sample=False,
|
||||
early_stopping=True,
|
||||
)
|
||||
|
||||
decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
self.assertListEqual(
|
||||
expected_summaries,
|
||||
|
||||
@@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, require_torch_multigpu, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
|
||||
@@ -156,7 +157,7 @@ class TransfoXLModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TransfoXLModel, TransfoXLLMHeadModel) if is_torch_available() else ()
|
||||
all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
|
||||
@@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -331,7 +332,7 @@ class XLMModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class XLMModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(
|
||||
|
||||
@@ -21,6 +21,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -479,7 +480,7 @@ class XLNetModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
XLNetModel,
|
||||
|
||||
Reference in New Issue
Block a user