Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300)

* Test case for #3936

* multigpu tests pass on pytorch 1.4.0

* Fixup

* multigpu tests pass on pytorch 1.5.0

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* rename multigpu to require_multigpu

* more doc
This commit is contained in:
Julien Chaumond
2020-05-18 20:34:50 -04:00
committed by GitHub
parent 9de4afa897
commit 4c06893610
12 changed files with 95 additions and 21 deletions

View File

@@ -17,7 +17,7 @@
import inspect
import logging
import os
from typing import Callable, Tuple
from typing import Callable, List, Tuple
import torch
from torch import Tensor, device, dtype, nn
@@ -110,11 +110,33 @@ class ModuleUtilsMixin:
@property
def device(self) -> device:
    """Return the device of the first tensor found on this module.

    The first registered parameter is used when there is one. If the module
    exposes no parameters, the first tensor stored as a plain attribute in the
    ``__dict__`` of this module or one of its submodules is used instead.

    Raises:
        StopIteration: if the module has neither parameters nor plain tensor
            attributes.
    """
    try:
        return next(self.parameters()).device
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5
        # NOTE(review): presumably replicas expose their weights as plain
        # tensor attributes rather than registered parameters -- confirm
        # against the PyTorch 1.5 release notes.
        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            # Collect every tensor stored directly in the module's __dict__.
            return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

        gen = self._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device
@property
def dtype(self) -> dtype:
    """Return the dtype of the first tensor found on this module.

    The first registered parameter is used when there is one. If the module
    exposes no parameters, the first tensor stored as a plain attribute in the
    ``__dict__`` of this module or one of its submodules is used instead.

    Raises:
        StopIteration: if the module has neither parameters nor plain tensor
            attributes.
    """
    try:
        return next(self.parameters()).dtype
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5
        # NOTE(review): presumably replicas expose their weights as plain
        # tensor attributes rather than registered parameters -- confirm
        # against the PyTorch 1.5 release notes.
        def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
            # Collect every tensor stored directly in the module's __dict__.
            return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

        gen = self._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
"""type: torch.Tensor -> torch.Tensor"""