Support T5 Generation (#3228)
* fix conflicts * update bart max length test * correct spelling mistakes * implemented model specific encode function * fix merge conflicts * better naming * save intermediate state -> need to rethink structure a bit * leave tf problem as it is for now * current version * add layers.pop * remove ipdb * make style * clean return cut decoding * remove ipdbs * Fix restoring layers in the decoders that don't exist. * push good intermediate solution for now * fix conflicts * always good to refuse to merge conflicts when rebasing * fix small bug * improve function calls * remove unused file * add correct scope behavior for t5_generate Co-authored-by: Morgan Funtowicz <funtowiczmo@gmail.com>
This commit is contained in:
committed by
GitHub
parent
656e1386a2
commit
bbf26c4e61
@@ -147,7 +147,7 @@ class ModelTesterMixin:
|
||||
4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions
|
||||
)
|
||||
decoder_attention_idx = 1
|
||||
if "lm_labels" in inputs_dict or "decoder_lm_labels" in inputs_dict: # loss will come first
|
||||
if "lm_labels" in inputs_dict: # loss will come first
|
||||
correct_outlen += 1 # compute loss
|
||||
decoder_attention_idx += 1
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
@@ -601,9 +601,9 @@ class ModelTesterMixin:
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
del inputs_dict["input_ids"]
|
||||
else:
|
||||
encoder_input_ids = inputs_dict["encoder_input_ids"]
|
||||
encoder_input_ids = inputs_dict["input_ids"]
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids)
|
||||
del inputs_dict["encoder_input_ids"]
|
||||
del inputs_dict["input_ids"]
|
||||
inputs_dict.pop("decoder_input_ids", None)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -615,7 +615,7 @@ class ModelTesterMixin:
|
||||
if not self.is_encoder_decoder:
|
||||
inputs_dict["inputs_embeds"] = wte(input_ids)
|
||||
else:
|
||||
inputs_dict["encoder_inputs_embeds"] = wte(encoder_input_ids)
|
||||
inputs_dict["inputs_embeds"] = wte(encoder_input_ids)
|
||||
inputs_dict["decoder_inputs_embeds"] = wte(decoder_input_ids)
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -624,9 +624,7 @@ class ModelTesterMixin:
|
||||
def test_lm_head_model_random_generate(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict.get(
|
||||
"input_ids", None
|
||||
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
|
||||
input_ids = inputs_dict.get("input_ids")
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models
|
||||
|
||||
Reference in New Issue
Block a user