Support T5 Generation (#3228)
* fix conflicts * update bart max length test * correct spelling mistakes * implemented model specific encode function * fix merge conflicts * better naming * save intermediate state -> need to rethink strucuture a bit * leave tf problem as it is for now * current version * add layers.pop * remove ipdb * make style * clean return cut decoding * remove ipdbs * Fix restoring layers in the decoders that doesnt exists. * push good intermediate solution for now * fix conflicts * always good to refuse to merge conflicts when rebasing * fix small bug * improve function calls * remove unused file * add correct scope behavior for t5_generate Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
committed by
GitHub
parent
656e1386a2
commit
bbf26c4e61
@@ -885,18 +885,17 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, encoder_inputs, attention_mask):
|
||||
assert attention_mask.shape == encoder_inputs.shape, "attn_mask.shape != encoder_input.shape: {} =! {}".format(
|
||||
attention_mask.shape, encoder_inputs.shape
|
||||
)
|
||||
if past is None: # first step
|
||||
encoder_outputs, decoder_cached_states = None, None
|
||||
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
|
||||
assert past is not None, "past has to be defined for encoder_outputs"
|
||||
|
||||
# first step, decoder_cached_states are empty
|
||||
if not past[1]:
|
||||
encoder_outputs, decoder_cached_states = past, None
|
||||
else:
|
||||
encoder_outputs, decoder_cached_states = past
|
||||
|
||||
input_ids = encoder_inputs
|
||||
return {
|
||||
"input_ids": input_ids, # ignored after first pass
|
||||
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"decoder_cached_states": decoder_cached_states,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
@@ -929,6 +928,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
past = ((new_enc_out, new_enc_mask), reordered_past)
|
||||
return past
|
||||
|
||||
def get_encoder(self):
|
||||
return self.model.encoder
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
|
||||
Reference in New Issue
Block a user