DataParallel fixes (#5733)
* DataParallel fixes:
  1. Switched to a more precise check:
     - if self.args.n_gpu > 1:
     + if isinstance(model, nn.DataParallel):
  2. Fixed tests: they require the same fixup under DataParallel as the training module.
* Another fix
This commit is contained in:
@@ -199,6 +199,9 @@ def train(args, train_dataset, model, tokenizer):
|
||||
{"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
|
||||
)
|
||||
|
||||
if isinstance(model, torch.nn.DataParallel):
|
||||
inputs["return_tuple"] = True
|
||||
|
||||
outputs = model(**inputs)
|
||||
# model outputs are always tuple in transformers (see doc)
|
||||
loss = outputs[0]
|
||||
|
||||
Reference in New Issue
Block a user