fix assisted decoding assistant model inputs (#27503)

* fix assisted decoding attention_cat

* fix attention_mask for assisted decoding

* fix attention_mask len

* fix attn len

* Use a cleaner way to prepare assistant model inputs

* fix param meaning

* fix param name

* fix assistant model inputs

* update token type ids

* fix assistant kwargs copy

* add encoder-decoder tests of assisted decoding

* check if assistant kwargs contains updated keys

* revert test

* fix whisper tests

* fix assistant kwargs

* revert whisper test

* delete _extend funcs
jiqing-feng
2023-11-27 22:23:54 +08:00
committed by GitHub
parent 307cf3a2ab
commit 1d7f406e19
4 changed files with 86 additions and 103 deletions

@@ -348,10 +348,6 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
         self.assertIsNotNone(model(**input_dict)["encoder_router_logits"][1])
         self.assertIsNotNone(model(**input_dict)["decoder_router_logits"][0])
 
-    @unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
-    def test_assisted_decoding_sample(self):
-        pass
-
 
 @require_torch
 @require_sentencepiece