fix assisted decoding assistant model inputs (#27503)

* fix assisted decoding attention_cat

* fix attention_mask for assisted decoding

* fix attention_mask len

* fix attn len

* Use a more clean way to prepare assistant models inputs

* fix param meaning

* fix param name

* fix assistant model inputs

* update token type ids

* fix assistant kwargs copy

* add encoder-decoder tests of assisted decoding

* check if assistant kwargs contains updated keys

* revert test

* fix whisper tests

* fix assistant kwargs

* revert whisper test

* delete _extend funcs
This commit is contained in:
jiqing-feng
2023-11-27 22:23:54 +08:00
committed by GitHub
parent 307cf3a2ab
commit 1d7f406e19
4 changed files with 86 additions and 103 deletions

View File

@@ -1036,10 +1036,6 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
@unittest.skip("Test does not fail individually but fails on the CI @ArthurZucker looking into it")
def test_assisted_decoding_sample(self):
pass
def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])