DataParallel fixes (#5733)
* DataParallel fixes: 1. switched to a more precise check - if self.args.n_gpu > 1: + if isinstance(model, nn.DataParallel): 2. fix tests - require the same fixup under DataParallel as the training module * another fix
This commit is contained in:
@@ -803,6 +803,8 @@ class ModelTesterMixin:
|
||||
|
||||
# Wrap model in nn.DataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
# Our model outputs do not work with DataParallel, so forcing return tuple.
|
||||
inputs_dict["return_tuple"] = True
|
||||
with torch.no_grad():
|
||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user