fix assisted decoding assistant model inputs (#27503)

* fix assisted decoding attention_cat

* fix attention_mask for assisted decoding

* fix attention_mask len

* fix attn len

* Use a more clean way to prepare assistant models inputs

* fix param meaning

* fix param name

* fix assistant model inputs

* update token type ids

* fix assistant kwargs copy

* add encoder-decoder tests of assisted decoding

* check if assistant kwargs contains updated keys

* revert test

* fix whisper tests

* fix assistant kwargs

* revert whisper test

* delete _extend funcs
This commit is contained in:
jiqing-feng
2023-11-27 22:23:54 +08:00
committed by GitHub
parent 307cf3a2ab
commit 1d7f406e19
4 changed files with 86 additions and 103 deletions

View File

@@ -726,10 +726,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
def test_disk_offload(self):
pass
@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
def test_assisted_decoding_sample(self):
pass
class SwitchTransformersEncoderOnlyModelTester:
def __init__(