Generate: store special token tensors under a unique variable name (#31980)
* rename stuff * english; this one shouldn't be changed * add a _ to the new var names * musicgen * derp
This commit is contained in:
@@ -3196,6 +3196,40 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
)
|
||||
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
|
||||
|
||||
def test_special_tokens_fall_back_to_model_default(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||
torch_device
|
||||
)
|
||||
test_bos_id = 50
|
||||
|
||||
# Sanity-check: the model has a BOS token set, and the first generated token is a BOS token
|
||||
gen_output = model.generate()
|
||||
self.assertTrue(model.generation_config.bos_token_id is not None)
|
||||
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
|
||||
|
||||
# If we pass a generation config **with** a BOS token, `generate` will use it
|
||||
generation_config = GenerationConfig(bos_token_id=test_bos_id)
|
||||
gen_output = model.generate(generation_config=generation_config)
|
||||
self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0])
|
||||
self.assertTrue(generation_config.bos_token_id == gen_output[0, 0])
|
||||
self.assertTrue(test_bos_id == gen_output[0, 0])
|
||||
|
||||
# If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from
|
||||
# `model.generation_config`
|
||||
generation_config = GenerationConfig(bos_token_id=None)
|
||||
gen_output = model.generate(generation_config=generation_config)
|
||||
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
|
||||
self.assertFalse(test_bos_id == gen_output[0, 0])
|
||||
self.assertTrue(generation_config.bos_token_id is None)
|
||||
|
||||
# Changing `model.generation_config` will affect fallback behavior
|
||||
model.generation_config.bos_token_id = test_bos_id
|
||||
gen_output = model.generate(generation_config=generation_config)
|
||||
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
|
||||
self.assertTrue(test_bos_id == gen_output[0, 0])
|
||||
self.assertTrue(generation_config.bos_token_id is None)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TokenHealingTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user