cleanup torch unittests (#6196)

* improve unit tests

this is a sample of one test according to the request in https://github.com/huggingface/transformers/issues/5973
before I apply it to the rest

* batch 1

* batch 2

* batch 3

* batch 4

* batch 5

* style

* non-tf template

* last deletion of check_loss_output
This commit is contained in:
Stas Bekman
2020-08-03 23:42:56 -07:00
committed by GitHub
parent b390a5672a
commit 5deed37f9f
18 changed files with 157 additions and 339 deletions

View File

@@ -175,9 +175,6 @@ class ReformerModelTester:
choice_labels,
)
def check_loss_output(self, result):
self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
model = ReformerModel(config=config)
model.to(torch_device)
@@ -186,8 +183,8 @@ class ReformerModelTester:
result = model(input_ids)
# 2 * hidden_size because we use reversible resnet layers
self.parent.assertListEqual(
list(result["last_hidden_state"].size()), [self.batch_size, self.seq_length, 2 * self.hidden_size],
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.seq_length, 2 * self.hidden_size)
)
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
@@ -206,10 +203,7 @@ class ReformerModelTester:
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=input_ids)
self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.check_loss_output(result)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_reformer_with_mlm(self, config, input_ids, input_mask, choice_labels):
config.is_decoder = False
@@ -217,10 +211,7 @@ class ReformerModelTester:
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=input_ids)
self.parent.assertListEqual(
list(result["logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
)
self.check_loss_output(result)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_reformer_model_with_attn_mask(
self, config, input_ids, input_mask, choice_labels, is_decoder=False
@@ -444,9 +435,8 @@ class ReformerModelTester:
result = model(
input_ids, attention_mask=input_mask, start_positions=choice_labels, end_positions=choice_labels,
)
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_past_buckets_states(self, config, input_ids, input_mask, choice_labels):
config.is_decoder = True
@@ -490,8 +480,7 @@ class ReformerModelTester:
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
self.check_loss_output(result)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
class ReformerTesterMixin: