Model output test (#6155)

* Use return_dict=True in all tests

* Formatting
This commit is contained in:
Sylvain Gugger
2020-07-31 09:44:37 -04:00
committed by GitHub
parent 86caab1e0b
commit d951c14ae4
26 changed files with 320 additions and 765 deletions

View File

@@ -115,6 +115,7 @@ class DPRModelTester:
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
return_dict=True,
)
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())
@@ -126,15 +127,11 @@ class DPRModelTester:
model = DPRContextEncoder(config=config)
model.to(torch_device)
model.eval()
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids)[0]
result = {
"embeddings": embeddings,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
)
def create_and_check_dpr_question_encoder(
@@ -143,15 +140,11 @@ class DPRModelTester:
model = DPRQuestionEncoder(config=config)
model.to(torch_device)
model.eval()
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
embeddings = model(input_ids)[0]
result = {
"embeddings": embeddings,
}
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertListEqual(
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
)
def create_and_check_dpr_reader(
@@ -160,12 +153,7 @@ class DPRModelTester:
model = DPRReader(config=config)
model.to(torch_device)
model.eval()
start_logits, end_logits, relevance_logits, *_ = model(input_ids, attention_mask=input_mask,)
result = {
"relevance_logits": relevance_logits,
"start_logits": start_logits,
"end_logits": end_logits,
}
result = model(input_ids, attention_mask=input_mask,)
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["relevance_logits"].size()), [self.batch_size])