Fix torch version comparisons (#18460)

Comparisons like
version.parse(torch.__version__) > version.parse("1.6")
are True for torch==1.6.0+cu101 or torch==1.6.0+cpu

version.parse(version.parse(torch.__version__).base_version) are preferred (and available in pytorch_utils.py
This commit is contained in:
LSinev
2022-08-03 20:37:18 +03:00
committed by GitHub
parent be41eaf55f
commit 02b176c4ce
34 changed files with 164 additions and 87 deletions

View File

@@ -30,7 +30,7 @@ from transformers import (
if is_apex_available():
from apex import amp
if version.parse(torch.__version__) >= version.parse("1.6"):
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast

View File

@@ -33,7 +33,7 @@ if is_apex_available():
from apex import amp
if version.parse(torch.__version__) >= version.parse("1.6"):
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast

View File

@@ -26,7 +26,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
if is_apex_available():
from apex import amp
if version.parse(torch.__version__) >= version.parse("1.6"):
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
_is_native_amp_available = True
from torch.cuda.amp import autocast