From 35cb101eae8a9678e91b809a8df77d34c259831d Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 20 Jul 2020 06:29:12 -0700 Subject: [PATCH] DataParallel fixes (#5733) * DataParallel fixes: 1. switched to a more precise check - if self.args.n_gpu > 1: + if isinstance(model, nn.DataParallel): 2. fix tests - require the same fixup under DataParallel as the training module * another fix --- examples/question-answering/run_squad.py | 3 +++ src/transformers/trainer.py | 4 ++-- tests/test_modeling_common.py | 2 ++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/question-answering/run_squad.py b/examples/question-answering/run_squad.py index 2bd4e90ff9..20b6c88922 100644 --- a/examples/question-answering/run_squad.py +++ b/examples/question-answering/run_squad.py @@ -199,6 +199,9 @@ def train(args, train_dataset, model, tokenizer): {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)} ) + if isinstance(model, torch.nn.DataParallel): + inputs["return_tuple"] = True + outputs = model(**inputs) # model outputs are always tuple in transformers (see doc) loss = outputs[0] diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 23333f49ca..566dd54a90 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -623,7 +623,7 @@ class Trainer: if self.args.past_index >= 0 and self._past is not None: inputs["mems"] = self._past # Our model outputs do not work with DataParallel, so forcing return tuple. - if self.args.n_gpu > 1: + if isinstance(model, nn.DataParallel): inputs["return_tuple"] = True outputs = model(**inputs) @@ -826,7 +826,7 @@ class Trainer: if self.args.past_index >= 0: inputs["mems"] = past # Our model outputs do not work with DataParallel, so forcing return tuple. - if self.args.n_gpu > 1: + if isinstance(model, nn.DataParallel): inputs["return_tuple"] = True with torch.no_grad(): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0021f23c3e..097c387543 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -803,6 +803,8 @@ class ModelTesterMixin: # Wrap model in nn.DataParallel model = torch.nn.DataParallel(model) + # Our model outputs do not work with DataParallel, so forcing return tuple. + inputs_dict["return_tuple"] = True with torch.no_grad(): _ = model(**self._prepare_for_class(inputs_dict, model_class))