Fix torch version comparisons (#18460)

Comparisons like
version.parse(torch.__version__) > version.parse("1.6")
are True for torch==1.6.0+cu101 or torch==1.6.0+cpu

version.parse(version.parse(torch.__version__).base_version) are preferred (and available in pytorch_utils.py
This commit is contained in:
LSinev
2022-08-03 20:37:18 +03:00
committed by GitHub
parent be41eaf55f
commit 02b176c4ce
34 changed files with 164 additions and 87 deletions

View File

@@ -22,7 +22,6 @@ import os
import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import Optional, Tuple, Union
@@ -48,6 +47,7 @@ from ...pytorch_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
is_torch_greater_than_1_6,
)
from ...utils import logging
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -157,7 +157,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if version.parse(torch.__version__) > version.parse("1.6.0"):
if is_torch_greater_than_1_6:
self.register_buffer(
"token_type_ids",
torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),