Refactoring the generate() function (#6949)
* first draft * show design proposition for new generate method * up * make better readable * make first version * gpt2 tests pass * make beam search for gpt2 work * add first encoder-decoder code * delete typo * make t5 work * save indermediate * make bart work with beam search * finish beam search bart / t5 * add default kwargs * make more tests pass * fix no bad words sampler * some fixes and tests for all distribution processors * fix test * fix rag slow tests * merge to master * add nograd to generate * make all slow tests pass * speed up generate * fix edge case bug * small fix * correct typo * add type hints and docstrings * fix typos in tests * add beam search tests * add tests for beam scorer * fix test rag * finish beam search tests * move generation tests in seperate file * fix generation tests * more tests * add aggressive generation tests * fix tests * add gpt2 sample test * add more docstring * add more docs * finish doc strings * apply some more of sylvains and sams comments * fix some typos * make fix copies * apply lysandres and sylvains comments * final corrections on examples * small fix for reformer
This commit is contained in:
committed by
GitHub
parent
b63beb743c
commit
a1bbcf3f6c
@@ -44,7 +44,6 @@ if is_torch_available():
|
||||
BertModel,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
|
||||
@@ -882,126 +881,6 @@ class ModelTesterMixin:
|
||||
with torch.no_grad():
|
||||
model(**inputs)[0]
|
||||
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
|
||||
# make sure that input_ids is at most of size 15
|
||||
input_ids = input_ids[..., :15]
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined, model needs input_ids
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True))
|
||||
else:
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating multiple sequences when no beam search generation
|
||||
# is not allowed as it would always generate the same sequences
|
||||
model.generate(input_ids, do_sample=False, num_beams=1, num_return_sequences=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [
|
||||
self._generate_random_bad_tokens(1, model.config),
|
||||
self._generate_random_bad_tokens(2, model.config),
|
||||
]
|
||||
output_tokens = model.generate(
|
||||
input_ids, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
|
||||
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = (inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# make sure that input_ids is at most of size 15
|
||||
input_ids = input_ids[..., :15]
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))
|
||||
else:
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5, num_beams=2))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating more sequences than having beams leads is not possible
|
||||
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(
|
||||
model.generate(
|
||||
input_ids,
|
||||
do_sample=True,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
)
|
||||
# num_return_sequences > 1, greedy
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [
|
||||
self._generate_random_bad_tokens(1, model.config),
|
||||
self._generate_random_bad_tokens(2, model.config),
|
||||
]
|
||||
output_tokens = model.generate(
|
||||
input_ids, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_ids.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.tolist(), bad_words_ids))
|
||||
|
||||
def _generate_random_bad_tokens(self, num_bad_tokens: int, config) -> List[int]:
|
||||
# special tokens cannot be bad tokens
|
||||
special_tokens = [x for x in [config.bos_token_id, config.eos_token_id, config.pad_token_id] if x is not None]
|
||||
# create random bad tokens that are not special tokens
|
||||
bad_tokens = []
|
||||
while len(bad_tokens) < num_bad_tokens:
|
||||
token = ids_tensor((1, 1), self.model_tester.vocab_size).squeeze(0).cpu().numpy()[0]
|
||||
if token not in special_tokens:
|
||||
bad_tokens.append(token)
|
||||
return bad_tokens
|
||||
|
||||
def _check_generated_ids(self, output_ids):
|
||||
for token_id in output_ids[0].tolist():
|
||||
self.assertGreaterEqual(token_id, 0)
|
||||
self.assertLess(token_id, self.model_tester.vocab_size)
|
||||
|
||||
def _check_match_tokens(self, generated_ids, bad_words_ids):
|
||||
# for all bad word tokens
|
||||
for bad_word_ids in bad_words_ids:
|
||||
# for all slices in batch
|
||||
for generated_ids_slice in generated_ids:
|
||||
# for all word idx
|
||||
for i in range(len(bad_word_ids), len(generated_ids_slice)):
|
||||
# if tokens match
|
||||
if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
|
||||
return True
|
||||
return False
|
||||
|
||||
@require_torch_multigpu
|
||||
def test_multigpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -1094,110 +973,3 @@ class ModelUtilsTest(unittest.TestCase):
|
||||
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||
self.assertEqual(model.config.output_hidden_states, True)
|
||||
self.assertEqual(model.config, config)
|
||||
|
||||
|
||||
@require_torch
|
||||
class UtilsFunctionsTest(unittest.TestCase):
|
||||
|
||||
# tests whether the top_k_top_p function behaves as expected
|
||||
def test_top_k_top_p_filtering(self):
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[
|
||||
8.2220991, # 3rd highest value; idx. 0
|
||||
-0.5620044,
|
||||
5.23229752,
|
||||
4.0386393,
|
||||
-6.8798378,
|
||||
-0.54785802,
|
||||
-3.2012153,
|
||||
2.92777176,
|
||||
1.88171953,
|
||||
7.35341276, # 5th highest value; idx. 9
|
||||
8.43207833, # 2nd highest value; idx. 10
|
||||
-9.85711836,
|
||||
-5.96209236,
|
||||
-1.13039161,
|
||||
-7.1115294,
|
||||
-0.8369633,
|
||||
-5.3186408,
|
||||
7.06427407,
|
||||
0.81369344,
|
||||
-0.82023817,
|
||||
-5.9179796,
|
||||
0.58813443,
|
||||
-6.99778438,
|
||||
4.71551189,
|
||||
-0.18771637,
|
||||
7.44020759, # 4th highest value; idx. 25
|
||||
9.38450987, # 1st highest value; idx. 26
|
||||
2.12662941,
|
||||
-9.32562038,
|
||||
2.35652522,
|
||||
], # cumulative prob of 5 highest values <= 0.6
|
||||
[
|
||||
0.58425518,
|
||||
4.53139238,
|
||||
-5.57510464,
|
||||
-6.28030699,
|
||||
-7.19529503,
|
||||
-4.02122551,
|
||||
1.39337037,
|
||||
-6.06707057,
|
||||
1.59480517,
|
||||
-9.643119,
|
||||
0.03907799,
|
||||
0.67231762,
|
||||
-8.88206726,
|
||||
6.27115922, # 4th highest value; idx. 13
|
||||
2.28520723,
|
||||
4.82767506,
|
||||
4.30421368,
|
||||
8.8275313, # 2nd highest value; idx. 17
|
||||
5.44029958, # 5th highest value; idx. 18
|
||||
-4.4735794,
|
||||
7.38579536, # 3rd highest value; idx. 20
|
||||
-2.91051663,
|
||||
2.61946077,
|
||||
-2.5674762,
|
||||
-9.48959302,
|
||||
-4.02922645,
|
||||
-1.35416918,
|
||||
9.67702323, # 1st highest value; idx. 27
|
||||
-5.89478553,
|
||||
1.85370467,
|
||||
], # cumulative prob of 5 highest values <= 0.6
|
||||
],
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
non_inf_expected_idx = torch.tensor(
|
||||
[[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
) # expected non filtered idx as noted above
|
||||
|
||||
non_inf_expected_output = torch.tensor(
|
||||
[
|
||||
8.2221,
|
||||
7.3534,
|
||||
8.4321,
|
||||
7.4402,
|
||||
9.3845,
|
||||
6.2712,
|
||||
8.8275,
|
||||
5.4403,
|
||||
7.3858,
|
||||
9.6770,
|
||||
], # expected non filtered values as noted above
|
||||
dtype=torch.float,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
output = top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
|
||||
non_inf_output = output[output != -float("inf")].to(device=torch_device)
|
||||
non_inf_idx = (output != -float("inf")).nonzero().to(device=torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(non_inf_expected_output, non_inf_output, atol=1e-12))
|
||||
self.assertTrue(torch.all(torch.eq(non_inf_expected_idx, non_inf_idx)))
|
||||
|
||||
Reference in New Issue
Block a user