cleanup tf unittests: part 2 (#6260)

* cleanup torch unittests: part 2

* remove trailing comma added by isort, and which breaks flake

* one more comma

* revert odd balls

* part 3: odd cases

* more ["key"] -> .key refactoring

* .numpy() is not needed

* more unncessary .numpy() removed

* more simplification
This commit is contained in:
Stas Bekman
2020-08-13 01:29:06 -07:00
committed by GitHub
parent bc820476a5
commit e983da0e7d
21 changed files with 159 additions and 239 deletions

View File

@@ -93,9 +93,9 @@ class TFT5ModelTester:
result = model(inputs)
result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids)
decoder_output = result["last_hidden_state"]
decoder_past = result["decoder_past_key_values"]
encoder_output = result["encoder_last_hidden_state"]
decoder_output = result.last_hidden_state
decoder_past = result.decoder_past_key_values
encoder_output = result.encoder_last_hidden_state
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertEqual(len(decoder_past), 2)
@@ -116,7 +116,7 @@ class TFT5ModelTester:
result = model(inputs_dict)
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
model = TFT5Model(config=config).get_decoder()