From 044281605f5fc037e1048d4360da971fdfbca370 Mon Sep 17 00:00:00 2001 From: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Date: Thu, 8 Aug 2024 18:44:53 +0200 Subject: [PATCH] Fix generate with `inputs_embeds` as input (#32493) * I think inputs_embeds has ndim == 3 * fix sequence length catch * add generate test * [run-slow]olmo, persimmon, gemma, gemma2, qwen2, llama * skip whisper * fix bart test * more fixes --- .../models/codegen/modeling_codegen.py | 15 +++--- .../models/cohere/modeling_cohere.py | 15 +++--- src/transformers/models/dbrx/modeling_dbrx.py | 15 +++--- .../models/falcon/modeling_falcon.py | 15 +++--- .../models/gemma/modeling_gemma.py | 15 +++--- .../models/gemma2/modeling_gemma2.py | 25 +++------- .../models/gpt_neo/modeling_gpt_neo.py | 15 +++--- .../models/gpt_neox/modeling_gpt_neox.py | 15 +++--- src/transformers/models/gptj/modeling_gptj.py | 15 +++--- .../models/llama/modeling_llama.py | 15 +++--- .../models/nemotron/modeling_nemotron.py | 15 +++--- src/transformers/models/olmo/modeling_olmo.py | 15 +++--- .../models/persimmon/modeling_persimmon.py | 15 +++--- src/transformers/models/phi/modeling_phi.py | 15 +++--- src/transformers/models/phi3/modeling_phi3.py | 15 +++--- .../models/qwen2/modeling_qwen2.py | 15 +++--- .../models/qwen2_moe/modeling_qwen2_moe.py | 15 +++--- .../models/stablelm/modeling_stablelm.py | 15 +++--- .../models/starcoder2/modeling_starcoder2.py | 15 +++--- tests/models/bart/test_modeling_bart.py | 5 ++ tests/models/bert/test_modeling_bert.py | 5 ++ tests/models/whisper/test_modeling_whisper.py | 5 ++ tests/test_modeling_common.py | 47 +++++++++++++++++++ 23 files changed, 213 insertions(+), 144 deletions(-) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 6452c2afa0..1920f350f5 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -756,17 +756,18 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 4bc03ade89..afcea137b5 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1132,17 +1132,18 @@ class CohereForCausalLM(CoherePreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index f07b910fcf..3486d5ed3a 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1403,17 +1403,18 @@ class DbrxForCausalLM(DbrxPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 5ddc1fba9e..edaef78f92 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1270,17 +1270,18 @@ class FalconForCausalLM(FalconPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 2710234717..a05d2c059e 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1143,17 +1143,18 @@ class GemmaForCausalLM(GemmaPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 8953238186..da929c0867 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -104,7 +104,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask @@ -301,7 +300,6 @@ class Gemma2Attention(nn.Module): attn_weights = attn_weights / self.config.attn_logit_softcapping attn_weights = torch.tanh(attn_weights) attn_weights = attn_weights * self.config.attn_logit_softcapping - if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask @@ -501,11 +499,9 @@ class Gemma2SdpaAttention(Gemma2Attention): key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - causal_mask = attention_mask if attention_mask is not None: causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. if query_states.device.type == "cuda" and causal_mask is not None: @@ -516,7 +512,6 @@ class Gemma2SdpaAttention(Gemma2Attention): # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False - attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, @@ -581,7 +576,6 @@ class Gemma2DecoderLayer(nn.Module): attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) if attention_mask.shape[-1] <= 1: # when decoding attention_mask = attention_mask[:, :, :, -self.sliding_window :] - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -1013,7 +1007,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel): output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -1080,7 +1073,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel): input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] - if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 @@ -1096,22 +1088,20 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: # The clone here is for the same reason as for `position_ids`. - model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)} + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device - + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min - attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, @@ -1122,7 +1112,6 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel): cache_position=cache_position, batch_size=batch_size, ) - model_inputs.update( { "position_ids": position_ids, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 8335268e84..3a606c37b3 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -970,17 +970,18 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 3e72eec072..22fbb0429f 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -1220,17 +1220,18 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.embed_out.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 39b0f1fc26..82540fe98e 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -1100,17 +1100,18 @@ class GPTJForCausalLM(GPTJPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index dd053c805f..4a0887629c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1265,17 +1265,18 @@ class LlamaForCausalLM(LlamaPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index c56a275bcd..db4bce273c 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -1136,17 +1136,18 @@ class NemotronForCausalLM(NemotronPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 61a8a2bf6b..1940660f61 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1176,17 +1176,18 @@ class OlmoForCausalLM(OlmoPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 885d744266..1e4f56c067 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -993,17 +993,18 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index c445459763..6d63c0ea7e 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1278,17 +1278,18 @@ class PhiForCausalLM(PhiPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 32871a37c0..601508e368 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1318,17 +1318,18 @@ class Phi3ForCausalLM(Phi3PreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index f30f3bdac7..28b414b190 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1176,17 +1176,18 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 8cf5c200d8..12ebe26e05 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1372,17 +1372,18 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 2f326184e1..988948a9a8 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1271,17 +1271,18 @@ class StableLmForCausalLM(StableLmPreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index d35b191149..f3b365776e 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1152,17 +1152,18 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel): # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds} + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: - model_inputs = {"input_ids": input_ids} + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if inputs_embeds is not None: - batch_size, sequence_length = inputs_embeds.shape - device = inputs_embeds.device + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device else: - batch_size, sequence_length = input_ids.shape - device = input_ids.device + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 61a0aa9091..dd0cb5bf4c 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -1540,3 +1540,8 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un @unittest.skip def test_save_load_fast_init_from_base(self): pass + + @unittest.skip(reason="Generate needs input ids") + def test_inputs_embeds_matches_input_ids_with_generate(self): + # generate only works with input ids for bartforcausalLM + pass diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py index 6ae9f6c279..766cc1c1bb 100644 --- a/tests/models/bert/test_modeling_bert.py +++ b/tests/models/bert/test_modeling_bert.py @@ -502,6 +502,11 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + @unittest.skip(reason="Generate needs input ids") + def test_inputs_embeds_matches_input_ids_with_generate(self): + # generate only works with input ids for bertforcausalLM + pass + def test_model_as_decoder_with_default_input_mask(self): # This regression test was failing with PyTorch < 1.3 ( diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index f43f29f565..269e726422 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -4058,6 +4058,11 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, # generate only works with input ids for whisper pass + @unittest.skip(reason="Generate needs input ids") + def test_inputs_embeds_matches_input_ids_with_generate(self): + # generate only works with input ids for whisper + pass + @unittest.skip(reason="Decoder can't keep attention grads") def test_retain_grad_hidden_states_attentions(self): return diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4a29942641..d9eddfbbc8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2819,6 +2819,53 @@ class ModelTesterMixin: )[0] self.assertTrue(torch.allclose(out_embeds, out_ids)) + def test_inputs_embeds_matches_input_ids_with_generate(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + if model_class.__name__ not in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES): + continue + model = model_class(config) + model.to(torch_device) + model.eval() + + model_forward_args = inspect.signature(model.forward).parameters + if "inputs_embeds" not in model_forward_args: + self.skipTest(reason="This model doesn't use `inputs_embeds`") + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1 + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + # some models infer position ids/attn mask differently when input ids + # by check if pad_token let's make sure no padding is in input ids + not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1 + input_ids[input_ids == pad_token_id] = not_pad_token_id + del inputs["input_ids"] + inputs_embeds = wte(input_ids) + out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)[:, -2:] + out_embeds = model.generate(inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2) + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1) + decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + inputs_embeds = wte(encoder_input_ids) + decoder_inputs_embeds = wte(decoder_input_ids) + out_ids = model.generate( + input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs, max_new_tokens=2 + )[:, -2:] + out_embeds = model.generate( + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + **inputs, + max_new_tokens=2, + ) + self.assertTrue(torch.allclose(out_embeds, out_ids)) + @require_torch_multi_gpu def test_multi_gpu_data_parallel_forward(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()