cleanup tf unittests: part 2 (#6260)
* cleanup torch unittests: part 2 * remove trailing comma added by isort, and which breaks flake * one more comma * revert odd balls * part 3: odd cases * more ["key"] -> .key refactoring * .numpy() is not needed * more unncessary .numpy() removed * more simplification
This commit is contained in:
@@ -141,9 +141,9 @@ class T5ModelTester:
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
)
|
||||
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
|
||||
decoder_output = result["last_hidden_state"]
|
||||
decoder_past = result["decoder_past_key_values"]
|
||||
encoder_output = result["encoder_last_hidden_state"]
|
||||
decoder_output = result.last_hidden_state
|
||||
decoder_past = result.decoder_past_key_values
|
||||
encoder_output = result.encoder_last_hidden_state
|
||||
|
||||
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
|
||||
|
||||
Reference in New Issue
Block a user