[examples] better PL version check (#8429)
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import packaging
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.utilities import rank_zero_info
|
from pytorch_lightning.utilities import rank_zero_info
|
||||||
|
|
||||||
@@ -33,16 +34,18 @@ from transformers.optimization import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
|
||||||
pkg = "pytorch_lightning"
|
def require_min_ver(pkg, min_ver):
|
||||||
min_ver = "1.0.4"
|
got_ver = pkg_resources.get_distribution(pkg).version
|
||||||
pkg_resources.require(f"{pkg}>={min_ver}")
|
if packaging.version.parse(got_ver) < packaging.version.parse(min_ver):
|
||||||
except pkg_resources.VersionConflict:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{pkg}>={min_ver} is required for a normal functioning of this module, but found {pkg}=={pkg_resources.get_distribution(pkg).version}. Try pip install -r examples/requirements.txt"
|
f"{pkg}>={min_ver} is required for a normal functioning of this module, but found {pkg}=={got_ver}. "
|
||||||
|
"Try: pip install -r examples/requirements.txt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
require_min_ver("pytorch_lightning", "1.0.4")
|
||||||
|
|
||||||
MODEL_MODES = {
|
MODEL_MODES = {
|
||||||
"base": AutoModel,
|
"base": AutoModel,
|
||||||
"sequence-classification": AutoModelForSequenceClassification,
|
"sequence-classification": AutoModelForSequenceClassification,
|
||||||
|
|||||||
Reference in New Issue
Block a user