Generate: consistently handle special tokens as tensors (#30624)
* tmp commit * [test_all] mvp * missing not * [test_all] final test fixes * fix musicgen_melody and rag * [test_all] empty commit * PR comments * Update src/transformers/generation/utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -168,7 +169,9 @@ class GenerationTesterMixin:
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||
num_interleave, dim=0
|
||||
)
|
||||
input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
model._prepare_special_tokens(generation_config)
|
||||
input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id
|
||||
attention_mask = None
|
||||
return encoder_outputs, input_ids, attention_mask
|
||||
|
||||
|
||||
Reference in New Issue
Block a user