Falcon: batched generation (#26137)
This commit is contained in:
@@ -19,8 +19,16 @@ import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoConfig, AutoModel, AutoTokenizer, FalconConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import CaptureLogger, require_torch, slow, tooslow, torch_device
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
FalconConfig,
|
||||
is_torch_available,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.testing_utils import CaptureLogger, require_bitsandbytes, require_torch, slow, tooslow, torch_device
|
||||
from transformers.utils import logging as transformers_logging
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@@ -502,6 +510,35 @@ class FalconLanguageGenerationTest(unittest.TestCase):
|
||||
outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
|
||||
self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)
|
||||
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
def test_batched_generation(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"tiiuae/falcon-7b",
|
||||
device_map="auto",
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
test_text = "A sequence: 1, 2" # should generate the rest of the sequence
|
||||
|
||||
unpadded_inputs = tokenizer([test_text], return_tensors="pt").to("cuda:0")
|
||||
unpadded_inputs.pop("token_type_ids")
|
||||
unpadded_gen_out = model.generate(**unpadded_inputs, max_new_tokens=20)
|
||||
unpadded_gen_text = tokenizer.batch_decode(unpadded_gen_out, skip_special_tokens=True)
|
||||
|
||||
dummy_text = "This is a longer text " * 2 # forces left-padding on `test_text`
|
||||
padded_inputs = tokenizer([test_text, dummy_text], return_tensors="pt", padding=True).to("cuda:0")
|
||||
padded_inputs.pop("token_type_ids")
|
||||
padded_gen_out = model.generate(**padded_inputs, max_new_tokens=20)
|
||||
padded_gen_text = tokenizer.batch_decode(padded_gen_out, skip_special_tokens=True)
|
||||
|
||||
expected_output = "A sequence: 1, 2, 3, 4, 5, 6, 7, 8, "
|
||||
self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists
|
||||
self.assertEqual(unpadded_gen_text[0], expected_output)
|
||||
self.assertEqual(padded_gen_text[0], expected_output)
|
||||
|
||||
|
||||
# TODO Lysandre: Remove this in version v4.34
|
||||
class FalconOverrideTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user