[Assistant Generation] Improve Encoder Decoder (#26701)

* [Assistant Generation] Improve enc dec

* save more

* Fix logit processor checks

* Clean

* make style

* fix deprecation

* fix generation test

* Apply suggestions from code review

* fix biogpt

* make style
This commit is contained in:
Patrick von Platen
2023-10-11 15:52:20 +02:00
committed by GitHub
parent 5334796d20
commit da69de17e8
4 changed files with 59 additions and 18 deletions

View File

@@ -2953,7 +2953,8 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
return outs
def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
inputs["foo"] = foo
@@ -2992,3 +2993,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
assistant_model=assistant,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
# Check that passing encoder_outputs directly also works as expected
encoder_outputs = assistant.get_encoder()(input_ids)
outputs_assisted = model.generate(
foo=True,
assistant_model=assistant,
encoder_outputs=encoder_outputs,
assistant_encoder_outputs=encoder_outputs,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())