[Generate] Facilitate PyTorch generate using ModelOutputs (#6735)
* fix generate for GPT2 Double Head * fix gpt2 double head model * fix bart / t5 * also add for no beam search * fix no beam search * fix encoder decoder * simplify t5 * simplify t5 * fix t5 tests * fix BART * fix transfo-xl * fix conflict * integrating sylvains and sams comments * fix tf past_decoder_key_values * fix enc dec test
This commit is contained in:
committed by
GitHub
parent
397f819615
commit
afc4ece462
@@ -159,17 +159,15 @@ class T5ModelTester:
|
||||
)
|
||||
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
decoder_output = result.last_hidden_state
|
||||
decoder_past = result.decoder_past_key_values
|
||||
decoder_past = result.past_key_values
|
||||
encoder_output = result.encoder_last_hidden_state
|
||||
|
||||
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(len(decoder_past), 2)
|
||||
self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output))
|
||||
# There should be `num_layers` key value embeddings stored in decoder_past[1]
|
||||
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
|
||||
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
|
||||
self.parent.assertEqual(len(decoder_past[1][0]), 4)
|
||||
# There should be `num_layers` key value embeddings stored in decoder_past
|
||||
self.parent.assertEqual(len(decoder_past), config.num_layers)
|
||||
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
|
||||
self.parent.assertEqual(len(decoder_past[0]), 4)
|
||||
|
||||
def create_and_check_with_lm_head(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user