Refactoring the generate() function (#6949)
* first draft * show design proposition for new generate method * up * make better readable * make first version * gpt2 tests pass * make beam search for gpt2 work * add first encoder-decoder code * delete typo * make t5 work * save indermediate * make bart work with beam search * finish beam search bart / t5 * add default kwargs * make more tests pass * fix no bad words sampler * some fixes and tests for all distribution processors * fix test * fix rag slow tests * merge to master * add nograd to generate * make all slow tests pass * speed up generate * fix edge case bug * small fix * correct typo * add type hints and docstrings * fix typos in tests * add beam search tests * add tests for beam scorer * fix test rag * finish beam search tests * move generation tests in seperate file * fix generation tests * more tests * add aggressive generation tests * fix tests * add gpt2 sample test * add more docstring * add more docs * finish doc strings * apply some more of sylvains and sams comments * fix some typos * make fix copies * apply lysandres and sylvains comments * final corrections on examples * small fix for reformer
This commit is contained in:
committed by
GitHub
parent
b63beb743c
commit
a1bbcf3f6c
@@ -20,6 +20,7 @@ from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_generation_utils import GenerationTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
@@ -377,7 +378,7 @@ class GPT2ModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (
|
||||
(GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2ForSequenceClassification)
|
||||
@@ -510,32 +511,17 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
def test_gpt2_sample(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
model.to(torch_device)
|
||||
input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device) # The president
|
||||
expected_output_ids = [
|
||||
464,
|
||||
1893,
|
||||
286,
|
||||
262,
|
||||
1578,
|
||||
1829,
|
||||
11,
|
||||
290,
|
||||
262,
|
||||
1893,
|
||||
286,
|
||||
262,
|
||||
1578,
|
||||
7526,
|
||||
11,
|
||||
423,
|
||||
587,
|
||||
287,
|
||||
262,
|
||||
2635,
|
||||
] # The president of the United States, and the president of the United Kingdom, have been in the White
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
torch.manual_seed(0)
|
||||
input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids.to(torch_device)
|
||||
output_ids = model.generate(input_ids, do_sample=True)
|
||||
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUT_STR = (
|
||||
"Today is a nice day and if you don't know anything about the state of play during your holiday"
|
||||
)
|
||||
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
|
||||
|
||||
Reference in New Issue
Block a user