DataParallel fixes (#5733)
* DataParallel fixes:
  1. Switched to a more precise check:
     - if self.args.n_gpu > 1:
     + if isinstance(model, nn.DataParallel):
  2. Fixed the tests: they require the same fixup under DataParallel as the training module.
* Another fix
This commit is contained in:
@@ -199,6 +199,9 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
|
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(model, torch.nn.DataParallel):
|
||||||
|
inputs["return_tuple"] = True
|
||||||
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
# model outputs are always tuple in transformers (see doc)
|
# model outputs are always tuple in transformers (see doc)
|
||||||
loss = outputs[0]
|
loss = outputs[0]
|
||||||
|
|||||||
@@ -623,7 +623,7 @@ class Trainer:
|
|||||||
if self.args.past_index >= 0 and self._past is not None:
|
if self.args.past_index >= 0 and self._past is not None:
|
||||||
inputs["mems"] = self._past
|
inputs["mems"] = self._past
|
||||||
# Our model outputs do not work with DataParallel, so forcing return tuple.
|
# Our model outputs do not work with DataParallel, so forcing return tuple.
|
||||||
if self.args.n_gpu > 1:
|
if isinstance(model, nn.DataParallel):
|
||||||
inputs["return_tuple"] = True
|
inputs["return_tuple"] = True
|
||||||
|
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
@@ -826,7 +826,7 @@ class Trainer:
|
|||||||
if self.args.past_index >= 0:
|
if self.args.past_index >= 0:
|
||||||
inputs["mems"] = past
|
inputs["mems"] = past
|
||||||
# Our model outputs do not work with DataParallel, so forcing return tuple.
|
# Our model outputs do not work with DataParallel, so forcing return tuple.
|
||||||
if self.args.n_gpu > 1:
|
if isinstance(model, nn.DataParallel):
|
||||||
inputs["return_tuple"] = True
|
inputs["return_tuple"] = True
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
@@ -803,6 +803,8 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
# Wrap model in nn.DataParallel
|
# Wrap model in nn.DataParallel
|
||||||
model = torch.nn.DataParallel(model)
|
model = torch.nn.DataParallel(model)
|
||||||
|
# Our model outputs do not work with DataParallel, so forcing return tuple.
|
||||||
|
inputs_dict["return_tuple"] = True
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user