Fix torch version comparisons (#18460)
Comparisons like
version.parse(torch.__version__) > version.parse("1.6")
are True for torch==1.6.0+cu101 or torch==1.6.0+cpu
version.parse(version.parse(torch.__version__).base_version) are preferred (and available in pytorch_utils.py
This commit is contained in:
@@ -22,7 +22,6 @@ import os
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from typing import Optional, Tuple, Union
|
||||
@@ -48,6 +47,7 @@ from ...pytorch_utils import (
|
||||
apply_chunking_to_forward,
|
||||
find_pruneable_heads_and_indices,
|
||||
prune_linear_layer,
|
||||
is_torch_greater_than_1_6,
|
||||
)
|
||||
from ...utils import logging
|
||||
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
|
||||
@@ -157,7 +157,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
if version.parse(torch.__version__) > version.parse("1.6.0"):
|
||||
if is_torch_greater_than_1_6:
|
||||
self.register_buffer(
|
||||
"token_type_ids",
|
||||
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
|
||||
|
||||
Reference in New Issue
Block a user