Generate: consistently handle special tokens as tensors (#30624)
* tmp commit * [test_all] mvp * missing not * [test_all] final test fixes * fix musicgen_melody and rag * [test_all] empty commit * PR comments * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -414,9 +414,11 @@ class SeamlessM4TModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase):
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||
num_interleave, dim=0
|
||||
)
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
model._prepare_special_tokens(generation_config)
|
||||
input_ids = (
|
||||
torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
|
||||
+ model._get_decoder_start_token_id()
|
||||
+ generation_config.decoder_start_token_id
|
||||
)
|
||||
attention_mask = None
|
||||
return encoder_outputs, input_ids, attention_mask
|
||||
|
||||
@@ -430,9 +430,11 @@ class SeamlessM4Tv2ModelWithSpeechInputTest(ModelTesterMixin, unittest.TestCase)
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||
num_interleave, dim=0
|
||||
)
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
model._prepare_special_tokens(generation_config)
|
||||
input_ids = (
|
||||
torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
|
||||
+ model._get_decoder_start_token_id()
|
||||
+ generation_config.decoder_start_token_id
|
||||
)
|
||||
attention_mask = None
|
||||
return encoder_outputs, input_ids, attention_mask
|
||||
|
||||
@@ -645,7 +645,9 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
||||
num_interleave, dim=0
|
||||
)
|
||||
input_ids = input_ids[:, :, 0]
|
||||
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + model._get_decoder_start_token_id()
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
model._prepare_special_tokens(generation_config)
|
||||
input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id
|
||||
attention_mask = None
|
||||
return encoder_outputs, input_ids, attention_mask
|
||||
|
||||
|
||||
@@ -833,10 +833,10 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||
num_interleave, dim=0
|
||||
)
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
model._prepare_special_tokens(generation_config)
|
||||
input_ids = input_ids[:, :, 0]
|
||||
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + torch.tensor(
|
||||
[model._get_decoder_start_token_id()], device=input_ids.device
|
||||
)
|
||||
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + generation_config.decoder_start_token_id
|
||||
attention_mask = None
|
||||
return encoder_outputs, input_ids, attention_mask
|
||||
|
||||
|
||||
Reference in New Issue
Block a user