[Assistant Generation] Improve Encoder Decoder (#26701)
* [Assistant Generation] Improve enc dec * save more * Fix logit processor checks * Clean * make style * fix deprecation * fix generation test * Apply suggestions from code review * fix biogpt * make style
This commit is contained in:
committed by
GitHub
parent
5334796d20
commit
da69de17e8
@@ -2953,7 +2953,8 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
return outs
|
||||
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
|
||||
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
|
||||
kwargs["encoder_outputs"] = encoder_outputs
|
||||
inputs = super().prepare_inputs_for_generation(*args, **kwargs)
|
||||
|
||||
inputs["foo"] = foo
|
||||
@@ -2992,3 +2993,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
assistant_model=assistant,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
# Check that passing encoder_outputs directly also works as expected
|
||||
encoder_outputs = assistant.get_encoder()(input_ids)
|
||||
|
||||
outputs_assisted = model.generate(
|
||||
foo=True,
|
||||
assistant_model=assistant,
|
||||
encoder_outputs=encoder_outputs,
|
||||
assistant_encoder_outputs=encoder_outputs,
|
||||
)
|
||||
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
|
||||
|
||||
Reference in New Issue
Block a user