Generate: consistently handle special tokens as tensors (#30624)

* tmp commit

* [test_all] mvp

* missing not

* [test_all] final test fixes

* fix musicgen_melody and rag

* [test_all] empty commit

* PR comments

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Joao Gante
2024-05-09 18:01:57 +01:00
committed by GitHub
parent 5413b8986d
commit 7130a22db9
12 changed files with 297 additions and 191 deletions

View File

@@ -414,9 +414,11 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = (
torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
+ model._get_decoder_start_token_id()
+ generation_config.decoder_start_token_id
)
attention_mask = None
return encoder_outputs, input_ids, attention_mask

View File

@@ -430,9 +430,11 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = (
torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
+ model._get_decoder_start_token_id()
+ generation_config.decoder_start_token_id
)
attention_mask = None
return encoder_outputs, input_ids, attention_mask

View File

@@ -645,7 +645,9 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
num_interleave, dim=0
)
input_ids = input_ids[:, :, 0]
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + model._get_decoder_start_token_id()
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id
attention_mask = None
return encoder_outputs, input_ids, attention_mask

View File

@@ -833,10 +833,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = input_ids[:, :, 0]
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + torch.tensor(
[model._get_decoder_start_token_id()], device=input_ids.device
)
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + generation_config.decoder_start_token_id
attention_mask = None
return encoder_outputs, input_ids, attention_mask