@@ -67,6 +67,8 @@ class ModelTesterMixin:
|
|||||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||||
return {
|
return {
|
||||||
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||||
|
if isinstance(v, torch.Tensor) and v.ndim != 0
|
||||||
|
else v
|
||||||
for k, v in inputs_dict.items()
|
for k, v in inputs_dict.items()
|
||||||
}
|
}
|
||||||
return inputs_dict
|
return inputs_dict
|
||||||
@@ -157,7 +159,7 @@ class ModelTesterMixin:
|
|||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs_dict)
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
attentions = outputs[-1]
|
attentions = outputs[-1]
|
||||||
self.assertEqual(model.config.output_hidden_states, False)
|
self.assertEqual(model.config.output_hidden_states, False)
|
||||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
|||||||
Reference in New Issue
Block a user