Generate: consistently handle special tokens as tensors (#30624)

* tmp commit

* [test_all] mvp

* missing not

* [test_all] final test fixes

* fix musicgen_melody and rag

* [test_all] empty commit

* PR comments

* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Joao Gante
2024-05-09 18:01:57 +01:00
committed by GitHub
parent 5413b8986d
commit 7130a22db9
12 changed files with 297 additions and 191 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import copy
import inspect
import tempfile
import unittest
@@ -168,7 +169,9 @@ class GenerationTesterMixin:
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
input_ids = torch.zeros_like(input_ids[:, :1]) + model._get_decoder_start_token_id()
generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config)
input_ids = torch.zeros_like(input_ids[:, :1]) + generation_config.decoder_start_token_id
attention_mask = None
return encoder_outputs, input_ids, attention_mask