Generate: decoder-only models can generate with inputs_embeds (#21405)
This commit is contained in:
@@ -519,47 +519,40 @@ class GenerationMixin:
|
|||||||
inputs_kwarg = model_kwargs.pop(input_name, None)
|
inputs_kwarg = model_kwargs.pop(input_name, None)
|
||||||
if inputs_kwarg is not None and inputs is not None:
|
if inputs_kwarg is not None and inputs is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`inputs`: {inputs}` were passed alongside "
|
f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
|
||||||
f"{input_name} which is not allowed."
|
|
||||||
f"Make sure to either pass {inputs} or {input_name}=..."
|
f"Make sure to either pass {inputs} or {input_name}=..."
|
||||||
)
|
)
|
||||||
elif inputs_kwarg is not None:
|
elif inputs_kwarg is not None:
|
||||||
inputs = inputs_kwarg
|
inputs = inputs_kwarg
|
||||||
|
|
||||||
# 3. models with `input_ids` can also make use of `inputs_embeds`
|
# 3. In the presence of `inputs_embeds` for text models:
|
||||||
if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
|
# - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
|
||||||
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
# doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
|
||||||
|
# input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
|
||||||
|
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
|
||||||
|
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
|
||||||
|
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
|
||||||
|
if not self.config.is_encoder_decoder:
|
||||||
|
has_inputs_embeds_forwarding = "inputs_embeds" in set(
|
||||||
|
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
|
||||||
|
)
|
||||||
|
if not has_inputs_embeds_forwarding:
|
||||||
|
raise ValueError(
|
||||||
|
f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
|
||||||
|
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
|
||||||
|
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if inputs is not None:
|
||||||
|
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
|
||||||
|
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
||||||
|
|
||||||
# 4. Only encoder-decoder models can have non `input_ids` input format
|
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
|
||||||
if not self.config.is_encoder_decoder and input_name != "input_ids":
|
|
||||||
raise ValueError(
|
|
||||||
f"If {input_name} is passed as model-specific keyword "
|
|
||||||
"input then model has to be an encoder-decoder and not a "
|
|
||||||
f"{self.__class__.__name__}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. if `inputs` is still None, try to create `input_ids` from BOS token
|
|
||||||
if inputs is None:
|
if inputs is None:
|
||||||
inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
|
inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
|
||||||
|
|
||||||
return inputs, input_name, model_kwargs
|
return inputs, input_name, model_kwargs
|
||||||
|
|
||||||
def _can_retrieve_inputs_from_name(
|
|
||||||
self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved
|
|
||||||
from name
|
|
||||||
"""
|
|
||||||
can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
|
|
||||||
inspect.signature(self.forward).parameters.keys()
|
|
||||||
)
|
|
||||||
|
|
||||||
if can_retrieve_inputs and inputs is not None:
|
|
||||||
raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")
|
|
||||||
|
|
||||||
return can_retrieve_inputs
|
|
||||||
|
|
||||||
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||||
"""
|
"""
|
||||||
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
|
Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
|
||||||
|
|||||||
@@ -981,7 +981,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.lm_head = new_embeddings
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
||||||
token_type_ids = kwargs.get("token_type_ids", None)
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if past_key_values:
|
if past_key_values:
|
||||||
@@ -1000,14 +1000,23 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||||
else:
|
else:
|
||||||
position_ids = None
|
position_ids = None
|
||||||
return {
|
|
||||||
"input_ids": input_ids,
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
"past_key_values": past_key_values,
|
if inputs_embeds is not None and past_key_values is None:
|
||||||
"use_cache": kwargs.get("use_cache"),
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
"position_ids": position_ids,
|
else:
|
||||||
"attention_mask": attention_mask,
|
model_inputs = {"input_ids": input_ids}
|
||||||
"token_type_ids": token_type_ids,
|
|
||||||
}
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": kwargs.get("use_cache"),
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
|
|||||||
@@ -2359,17 +2359,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
|
|
||||||
self.assertTrue(diff < 1e-4)
|
self.assertTrue(diff < 1e-4)
|
||||||
|
|
||||||
def test_decoder_generate_with_inputs_embeds(self):
|
|
||||||
article = """I need input_ids to generate"""
|
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
|
||||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=5).to(torch_device)
|
|
||||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
|
||||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
|
||||||
|
|
||||||
# cannot generate from `inputs_embeds` for decoder only
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
model.generate(inputs_embeds=inputs_embeds)
|
|
||||||
|
|
||||||
def test_generate_input_ids_as_kwarg(self):
|
def test_generate_input_ids_as_kwarg(self):
|
||||||
article = """I need input_ids to generate"""
|
article = """I need input_ids to generate"""
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
@@ -2417,8 +2406,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
|
|
||||||
def test_generate_too_many_encoder_kwargs(self):
|
def test_generate_too_many_encoder_kwargs(self):
|
||||||
article = """I need input_ids to generate"""
|
article = """I need input_ids to generate"""
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
|
||||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
|
model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=10).to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
|
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
|
||||||
@@ -3128,3 +3119,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
eos_token_id = [873]
|
eos_token_id = [873]
|
||||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||||
|
|
||||||
|
def test_generate_from_input_embeds_decoder_only(self):
|
||||||
|
# Note: the model must support generation from input embeddings
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
|
||||||
|
text = "Hello world"
|
||||||
|
input_ids = tokenizer.encode(text, return_tensors="pt")
|
||||||
|
|
||||||
|
# Traditional way of generating text
|
||||||
|
outputs_from_ids = model.generate(input_ids)
|
||||||
|
|
||||||
|
# Same thing, but from input embeddings
|
||||||
|
inputs_embeds = model.transformer.wte(input_ids)
|
||||||
|
outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
|
||||||
|
self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
|
||||||
|
|
||||||
|
# But if we pass different inputs_embeds, we should get different outputs
|
||||||
|
torch.manual_seed(0)
|
||||||
|
random_embeds = torch.rand_like(inputs_embeds)
|
||||||
|
outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
|
||||||
|
|||||||
Reference in New Issue
Block a user