[Generate] Facilitate PyTorch generate using ModelOutputs (#6735)

* fix generate for GPT2 Double Head

* fix gpt2 double head model

* fix  bart / t5

* also add for no beam search

* fix no beam search

* fix encoder decoder

* simplify t5

* simplify t5

* fix t5 tests

* fix BART

* fix transfo-xl

* fix conflict

* integrating sylvains and sams comments

* fix tf past_decoder_key_values

* fix enc dec test
This commit is contained in:
Patrick von Platen
2020-09-01 12:38:25 +02:00
committed by GitHub
parent 397f819615
commit afc4ece462
20 changed files with 393 additions and 259 deletions

View File

@@ -289,9 +289,9 @@ class GPT2ModelTester:
}
result = model(**inputs)
self.parent.assertEqual(result.lm_loss.shape, ())
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
)
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
@@ -324,7 +324,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
all_generative_model_classes = (
(GPT2LMHeadModel,) if is_torch_available() else ()
(GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
test_missing_keys = False