Model output test (#6155)
* Use return_dict=True in all tests * Formatting
This commit is contained in:
@@ -115,6 +115,7 @@ class DPRModelTester:
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
return_dict=True,
|
||||
)
|
||||
config = DPRConfig(projection_dim=self.projection_dim, **config.to_dict())
|
||||
|
||||
@@ -126,15 +127,11 @@ class DPRModelTester:
|
||||
model = DPRContextEncoder(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
|
||||
embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
|
||||
embeddings = model(input_ids)[0]
|
||||
|
||||
result = {
|
||||
"embeddings": embeddings,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
|
||||
list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_dpr_question_encoder(
|
||||
@@ -143,15 +140,11 @@ class DPRModelTester:
|
||||
model = DPRQuestionEncoder(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
embeddings = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0]
|
||||
embeddings = model(input_ids, token_type_ids=token_type_ids)[0]
|
||||
embeddings = model(input_ids)[0]
|
||||
|
||||
result = {
|
||||
"embeddings": embeddings,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
||||
result = model(input_ids, token_type_ids=token_type_ids)
|
||||
result = model(input_ids)
|
||||
self.parent.assertListEqual(
|
||||
list(result["embeddings"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
|
||||
list(result["pooler_output"].size()), [self.batch_size, self.projection_dim or self.hidden_size]
|
||||
)
|
||||
|
||||
def create_and_check_dpr_reader(
|
||||
@@ -160,12 +153,7 @@ class DPRModelTester:
|
||||
model = DPRReader(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
start_logits, end_logits, relevance_logits, *_ = model(input_ids, attention_mask=input_mask,)
|
||||
result = {
|
||||
"relevance_logits": relevance_logits,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
result = model(input_ids, attention_mask=input_mask,)
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["relevance_logits"].size()), [self.batch_size])
|
||||
|
||||
Reference in New Issue
Block a user