Fix generate with inputs_embeds as input (#32493)
* I think inputs_embeds has ndim == 3 * fix sequence length catch * add generate test * [run-slow]olmo, persimmon, gemma, gemma2, qwen2, llama * skip whisper * fix bart test * more fixes
This commit is contained in:
@@ -2819,6 +2819,53 @@ class ModelTesterMixin:
|
||||
)[0]
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
def test_inputs_embeds_matches_input_ids_with_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
if model_class.__name__ not in get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES):
|
||||
continue
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model_forward_args = inspect.signature(model.forward).parameters
|
||||
if "inputs_embeds" not in model_forward_args:
|
||||
self.skipTest(reason="This model doesn't use `inputs_embeds`")
|
||||
|
||||
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
pad_token_id = config.pad_token_id if config.pad_token_id is not None else 1
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
if not self.is_encoder_decoder:
|
||||
input_ids = inputs["input_ids"]
|
||||
# some models infer position ids/attn mask differently when input ids
|
||||
# by check if pad_token let's make sure no padding is in input ids
|
||||
not_pad_token_id = pad_token_id + 1 if max(0, pad_token_id - 1) == 0 else pad_token_id - 1
|
||||
input_ids[input_ids == pad_token_id] = not_pad_token_id
|
||||
del inputs["input_ids"]
|
||||
inputs_embeds = wte(input_ids)
|
||||
out_ids = model.generate(input_ids=input_ids, **inputs, max_new_tokens=2)[:, -2:]
|
||||
out_embeds = model.generate(inputs_embeds=inputs_embeds, **inputs, max_new_tokens=2)
|
||||
else:
|
||||
encoder_input_ids = inputs["input_ids"]
|
||||
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
|
||||
encoder_input_ids[encoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
|
||||
decoder_input_ids[decoder_input_ids == pad_token_id] = max(0, pad_token_id + 1)
|
||||
del inputs["input_ids"]
|
||||
inputs.pop("decoder_input_ids", None)
|
||||
inputs_embeds = wte(encoder_input_ids)
|
||||
decoder_inputs_embeds = wte(decoder_input_ids)
|
||||
out_ids = model.generate(
|
||||
input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids, **inputs, max_new_tokens=2
|
||||
)[:, -2:]
|
||||
out_embeds = model.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||
**inputs,
|
||||
max_new_tokens=2,
|
||||
)
|
||||
self.assertTrue(torch.allclose(out_embeds, out_ids))
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user