cleanup tf unittests: part 2 (#6260)
* cleanup torch unittests: part 2 * remove trailing comma added by isort, and which breaks flake * one more comma * revert odd balls * part 3: odd cases * more ["key"] -> .key refactoring * .numpy() is not needed * more unncessary .numpy() removed * more simplification
This commit is contained in:
@@ -93,9 +93,9 @@ class TFT5ModelTester:
|
||||
result = model(inputs)
|
||||
|
||||
result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids)
|
||||
decoder_output = result["last_hidden_state"]
|
||||
decoder_past = result["decoder_past_key_values"]
|
||||
encoder_output = result["encoder_last_hidden_state"]
|
||||
decoder_output = result.last_hidden_state
|
||||
decoder_past = result.decoder_past_key_values
|
||||
encoder_output = result.encoder_last_hidden_state
|
||||
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertEqual(len(decoder_past), 2)
|
||||
@@ -116,7 +116,7 @@ class TFT5ModelTester:
|
||||
|
||||
result = model(inputs_dict)
|
||||
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
|
||||
model = TFT5Model(config=config).get_decoder()
|
||||
|
||||
Reference in New Issue
Block a user