Generate: Add new decoding strategy "DoLa" in .generate() (#29619)

Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
Yung-Sung Chuang
2024-07-09 09:37:38 -07:00
committed by GitHub
parent 99c0e55335
commit d094d8d9ec
7 changed files with 530 additions and 5 deletions

View File

@@ -1264,6 +1264,55 @@ class GenerationTesterMixin:
for output in (output_greedy, output_prompt_lookup):
self._check_outputs(output, input_ids, model.config, use_cache=True)
def test_dola_decoding_sample(self):
# TODO (joao): investigate skips, try to reduce incompatibilities
for model_class in self.all_generative_model_classes:
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support DoLa decoding")
if any(model_name in model_class.__name__.lower() for model_name in ["reformer"]):
self.skipTest("Skip Reformer as the lm_head input size is 2 * hidden size, adopted from Rev Nets.")
if any(model_name in model_class.__name__.lower() for model_name in ["marian", "mbart", "pegasus"]):
self.skipTest("DoLa is not supported for models that don't return layerwise hidden states")
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
config, input_ids, attention_mask = self._get_input_ids_and_config()
# Some models don't support the cache and returning past_key_values
if not hasattr(config, "use_cache"):
config.use_cache = False
else:
config.use_cache = True
# Encoder-decoder models are not supported
if config.is_encoder_decoder:
self.skipTest("DoLa is not supported for encoder-decoder models")
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
if model.get_output_embeddings() is None:
self.skipTest("DoLa is not supported for models that don't have output embeddings")
# Sets dola generation arguments such that:
# a) no EOS is generated, to ensure generation doesn't break early
# b) there are at least two forward passes in the main model, to ensure the input preparation of
# the main model is correct
generation_kwargs = {
"eos_token_id": -1, # see a)
"max_new_tokens": 4, # see b)
"num_beams": 1,
"do_sample": True,
"output_scores": True,
"output_logits": True,
"output_hidden_states": True,
"output_attentions": self.has_attentions,
"return_dict_in_generate": True,
}
generation_kwargs.update({"dola_layers": "low"})
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_dola = model.generate(input_ids, **model_kwargs, **generation_kwargs)
self._check_outputs(output_dola, input_ids, model.config, use_cache=config.use_cache)
def test_assisted_decoding_sample(self):
# In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with