[Generate] Facilitate PyTorch generate using ModelOutputs (#6735)
* fix generate for GPT2 Double Head * fix gpt2 double head model * fix bart / t5 * also add for no beam search * fix no beam search * fix encoder decoder * simplify t5 * simplify t5 * fix t5 tests * fix BART * fix transfo-xl * fix conflict * integrating sylvains and sams comments * fix tf past_decoder_key_values * fix enc dec test
This commit is contained in:
committed by
GitHub
parent
397f819615
commit
afc4ece462
@@ -289,9 +289,9 @@ class GPT2ModelTester:
|
||||
}
|
||||
|
||||
result = model(**inputs)
|
||||
self.parent.assertEqual(result.lm_loss.shape, ())
|
||||
self.parent.assertEqual(result.loss.shape, ())
|
||||
self.parent.assertEqual(
|
||||
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
|
||||
result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
|
||||
)
|
||||
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
|
||||
|
||||
@@ -324,7 +324,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||
all_generative_model_classes = (
|
||||
(GPT2LMHeadModel,) if is_torch_available() else ()
|
||||
(GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
|
||||
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
||||
test_missing_keys = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user