Factor out methods (#10215)
This commit is contained in:
@@ -86,6 +86,36 @@ def find_pruneable_heads_and_indices(
|
|||||||
return heads, index
|
return heads, index
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
|
||||||
|
try:
|
||||||
|
return next(parameter.parameters()).device
|
||||||
|
except StopIteration:
|
||||||
|
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||||
|
|
||||||
|
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||||
|
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||||
|
return tuples
|
||||||
|
|
||||||
|
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||||
|
first_tuple = next(gen)
|
||||||
|
return first_tuple[1].device
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
|
||||||
|
try:
|
||||||
|
return next(parameter.parameters()).dtype
|
||||||
|
except StopIteration:
|
||||||
|
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||||
|
|
||||||
|
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||||
|
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||||
|
return tuples
|
||||||
|
|
||||||
|
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||||
|
first_tuple = next(gen)
|
||||||
|
return first_tuple[1].dtype
|
||||||
|
|
||||||
|
|
||||||
class ModuleUtilsMixin:
|
class ModuleUtilsMixin:
|
||||||
"""
|
"""
|
||||||
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
|
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
|
||||||
@@ -145,36 +175,14 @@ class ModuleUtilsMixin:
|
|||||||
:obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
:obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
||||||
device).
|
device).
|
||||||
"""
|
"""
|
||||||
try:
|
return get_parameter_device(self)
|
||||||
return next(self.parameters()).device
|
|
||||||
except StopIteration:
|
|
||||||
# For nn.DataParallel compatibility in PyTorch 1.5
|
|
||||||
|
|
||||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
|
||||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
|
||||||
return tuples
|
|
||||||
|
|
||||||
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
|
||||||
first_tuple = next(gen)
|
|
||||||
return first_tuple[1].device
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> dtype:
|
def dtype(self) -> dtype:
|
||||||
"""
|
"""
|
||||||
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
:obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||||
"""
|
"""
|
||||||
try:
|
return get_parameter_dtype(self)
|
||||||
return next(self.parameters()).dtype
|
|
||||||
except StopIteration:
|
|
||||||
# For nn.DataParallel compatibility in PyTorch 1.5
|
|
||||||
|
|
||||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
|
||||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
|
||||||
return tuples
|
|
||||||
|
|
||||||
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
|
||||||
first_tuple = next(gen)
|
|
||||||
return first_tuple[1].dtype
|
|
||||||
|
|
||||||
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
|
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -1238,7 +1246,7 @@ class PoolerStartLogits(nn.Module):
|
|||||||
x = self.dense(hidden_states).squeeze(-1)
|
x = self.dense(hidden_states).squeeze(-1)
|
||||||
|
|
||||||
if p_mask is not None:
|
if p_mask is not None:
|
||||||
if next(self.parameters()).dtype == torch.float16:
|
if get_parameter_dtype(self) == torch.float16:
|
||||||
x = x * (1 - p_mask) - 65500 * p_mask
|
x = x * (1 - p_mask) - 65500 * p_mask
|
||||||
else:
|
else:
|
||||||
x = x * (1 - p_mask) - 1e30 * p_mask
|
x = x * (1 - p_mask) - 1e30 * p_mask
|
||||||
@@ -1305,7 +1313,7 @@ class PoolerEndLogits(nn.Module):
|
|||||||
x = self.dense_1(x).squeeze(-1)
|
x = self.dense_1(x).squeeze(-1)
|
||||||
|
|
||||||
if p_mask is not None:
|
if p_mask is not None:
|
||||||
if next(self.parameters()).dtype == torch.float16:
|
if get_parameter_dtype(self) == torch.float16:
|
||||||
x = x * (1 - p_mask) - 65500 * p_mask
|
x = x * (1 - p_mask) - 65500 * p_mask
|
||||||
else:
|
else:
|
||||||
x = x * (1 - p_mask) - 1e30 * p_mask
|
x = x * (1 - p_mask) - 1e30 * p_mask
|
||||||
|
|||||||
Reference in New Issue
Block a user