[s2s] warn if --fp16 for torch 1.6 (#6977)
This commit is contained in:
@@ -3,6 +3,7 @@ import glob
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
@@ -10,6 +11,7 @@ from typing import Dict, List, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||||
@@ -354,7 +356,8 @@ def main(args, model=None) -> SummarizationModule:
|
|||||||
model: SummarizationModule = SummarizationModule(args)
|
model: SummarizationModule = SummarizationModule(args)
|
||||||
else:
|
else:
|
||||||
model: SummarizationModule = TranslationModule(args)
|
model: SummarizationModule = TranslationModule(args)
|
||||||
|
if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
|
||||||
|
warnings.warn("FP16 only seems to work with torch 1.5+apex")
|
||||||
dataset = Path(args.data_dir).name
|
dataset = Path(args.data_dir).name
|
||||||
if (
|
if (
|
||||||
args.logger_name == "default"
|
args.logger_name == "default"
|
||||||
|
|||||||
Reference in New Issue
Block a user