Fix the CI (#4903)

* Fix CI
This commit is contained in:
Sylvain Gugger
2020-06-10 09:26:06 -04:00
committed by GitHub
parent 0a375f5abd
commit ac99217e92

View File

@@ -67,6 +67,8 @@ class ModelTesterMixin:
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
return { return {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous() k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if isinstance(v, torch.Tensor) and v.ndim != 0
else v
for k, v in inputs_dict.items() for k, v in inputs_dict.items()
} }
return inputs_dict return inputs_dict
@@ -157,7 +159,7 @@ class ModelTesterMixin:
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs_dict) outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1] attentions = outputs[-1]
self.assertEqual(model.config.output_hidden_states, False) self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)