@@ -67,6 +67,8 @@ class ModelTesterMixin:
|
||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
return {
|
||||
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||
if isinstance(v, torch.Tensor) and v.ndim != 0
|
||||
else v
|
||||
for k, v in inputs_dict.items()
|
||||
}
|
||||
return inputs_dict
|
||||
@@ -157,7 +159,7 @@ class ModelTesterMixin:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs_dict)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs[-1]
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
Reference in New Issue
Block a user