Generate: store special token tensors under a unique variable name (#31980)

* rename stuff

* english; this one shouldn't be changed

* add a _ to the new var names

* musicgen

* derp
This commit is contained in:
Joao Gante
2024-07-22 14:06:49 +01:00
committed by GitHub
parent aa8f86a421
commit c38c55f4fb
4 changed files with 187 additions and 273 deletions

View File

@@ -3196,6 +3196,40 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
)
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
def test_special_tokens_fall_back_to_model_default(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
torch_device
)
test_bos_id = 50
# Sanity-check: the model has a BOS token set, and the first generated token is a BOS token
gen_output = model.generate()
self.assertTrue(model.generation_config.bos_token_id is not None)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
# If we pass a generation config **with** a BOS token, `generate` will use it
generation_config = GenerationConfig(bos_token_id=test_bos_id)
gen_output = model.generate(generation_config=generation_config)
self.assertFalse(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(test_bos_id == gen_output[0, 0])
# If we pass a generation config **without** a BOS token, `generate` will fetch the BOS token from
# `model.generation_config`
generation_config = GenerationConfig(bos_token_id=None)
gen_output = model.generate(generation_config=generation_config)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertFalse(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)
# Changing `model.generation_config` will affect fallback behavior
model.generation_config.bos_token_id = test_bos_id
gen_output = model.generate(generation_config=generation_config)
self.assertTrue(model.generation_config.bos_token_id == gen_output[0, 0])
self.assertTrue(test_bos_id == gen_output[0, 0])
self.assertTrue(generation_config.bos_token_id is None)
@require_torch
class TokenHealingTestCase(unittest.TestCase):