DataParallel fixes (#5733)

* DataParallel fixes:

1. switched to a more precise check
-        if self.args.n_gpu > 1:
+        if isinstance(model, nn.DataParallel):

2. fix tests - require the same fixup under DataParallel as the training module

* another fix
This commit is contained in:
Stas Bekman
2020-07-20 06:29:12 -07:00
committed by GitHub
parent 290b6e18ac
commit 35cb101eae
3 changed files with 7 additions and 2 deletions

View File

@@ -199,6 +199,9 @@ def train(args, train_dataset, model, tokenizer):
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
)
if isinstance(model, torch.nn.DataParallel):
inputs["return_tuple"] = True
outputs = model(**inputs)
# model outputs are always tuple in transformers (see doc)
loss = outputs[0]