Generate: Add new decoding strategy "DoLa" in .generate() (#29619)
Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
@@ -1264,6 +1264,55 @@ class GenerationTesterMixin:
|
||||
for output in (output_greedy, output_prompt_lookup):
|
||||
self._check_outputs(output, input_ids, model.config, use_cache=True)
|
||||
|
||||
def test_dola_decoding_sample(self):
|
||||
# TODO (joao): investigate skips, try to reduce incompatibilities
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest(reason="Stateful models don't support DoLa decoding")
|
||||
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
|
||||
self.skipTest("Skip Reformer as the lm_head input size is 2 * hidden size, adopted from Rev Nets.")
|
||||
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["marian", "mbart", "pegasus"]):
|
||||
self.skipTest("DoLa is not supported for models that don't return layerwise hidden states")
|
||||
|
||||
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
# Some models don't support the cache and returning past_key_values
|
||||
if not hasattr(config, "use_cache"):
|
||||
config.use_cache = False
|
||||
else:
|
||||
config.use_cache = True
|
||||
|
||||
# Encoder-decoder models are not supported
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest("DoLa is not supported for encoder-decoder models")
|
||||
config.is_decoder = True
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
if model.get_output_embeddings() is None:
|
||||
self.skipTest("DoLa is not supported for models that don't have output embeddings")
|
||||
# Sets dola generation arguments such that:
|
||||
# a) no EOS is generated, to ensure generation doesn't break early
|
||||
# b) there are at least two forward passes in the main model, to ensure the input preparation of
|
||||
# the main model is correct
|
||||
generation_kwargs = {
|
||||
"eos_token_id": -1, # see a)
|
||||
"max_new_tokens": 4, # see b)
|
||||
"num_beams": 1,
|
||||
"do_sample": True,
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
generation_kwargs.update({"dola_layers": "low"})
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
|
||||
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
|
||||
|
||||
def test_assisted_decoding_sample(self):
|
||||
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
|
||||
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
|
||||
|
||||
@@ -839,7 +839,6 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@slow
|
||||
@@ -898,3 +897,24 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
|
||||
|
||||
def test_model_2b_bf16_dola(self):
|
||||
model_id = "google/gemma-2b"
|
||||
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing an experiment and need to get the mass of a block. The problem is, it has no scale",
|
||||
"Hi today we have the review for a <strong>2016/2017</strong> season of",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(
|
||||
**inputs, max_new_tokens=20, do_sample=False, dola_layers="low", repetition_penalty=1.2
|
||||
)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@@ -703,6 +703,29 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_model_7b_dola_generation(self):
|
||||
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"Simply put, the theory of relativity states that 1) time and space are relative, and 2) the laws of "
|
||||
"physics are the same for all observers in uniform motion relative to one another.\n\nThe theory of "
|
||||
"relativity was developed by Albert Einstein in the early 20th century, and it revolutionized our "
|
||||
"understanding of space and time."
|
||||
)
|
||||
prompt = "Simply put, the theory of relativity states that "
|
||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-chat-hf", device_map="sequential", torch_dtype=torch.float16
|
||||
)
|
||||
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(
|
||||
**model_inputs, max_new_tokens=64, top_p=None, temperature=1, do_sample=False, dola_layers="low"
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_read_token
|
||||
|
||||
@@ -555,6 +555,30 @@ class MistralIntegrationTest(unittest.TestCase):
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text)
|
||||
|
||||
@slow
|
||||
def test_model_7b_dola_generation(self):
|
||||
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
|
||||
EXPECTED_TEXT_COMPLETION = (
|
||||
"""My favourite condiment is 100% ketchup. I love it on everything, and I’m not ash"""
|
||||
)
|
||||
prompt = "My favourite condiment is "
|
||||
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
|
||||
model = MistralForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
|
||||
|
||||
# greedy generation outputs
|
||||
generated_ids = model.generate(
|
||||
input_ids, max_new_tokens=20, temperature=0, dola_layers="low", repetition_penalty=1.2
|
||||
)
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
del model
|
||||
backend_empty_cache(torch_device)
|
||||
gc.collect()
|
||||
|
||||
@require_bitsandbytes
|
||||
@slow
|
||||
@require_flash_attn
|
||||
|
||||
Reference in New Issue
Block a user