Fix torch version comparisons (#18460)
Comparisons like
version.parse(torch.__version__) > version.parse("1.6")
are True for torch==1.6.0+cu101 or torch==1.6.0+cpu
version.parse(version.parse(torch.__version__).base_version) are preferred (and available in pytorch_utils.py
This commit is contained in:
@@ -30,7 +30,7 @@ from transformers import (
|
||||
if is_apex_available():
|
||||
from apex import amp
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("1.6"):
|
||||
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
|
||||
_is_native_amp_available = True
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ if is_apex_available():
|
||||
from apex import amp
|
||||
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("1.6"):
|
||||
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
|
||||
_is_native_amp_available = True
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
||||
if is_apex_available():
|
||||
from apex import amp
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("1.6"):
|
||||
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
|
||||
_is_native_amp_available = True
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
Reference in New Issue
Block a user