[s2s] warn if --fp16 for torch 1.6 (#6977)

This commit is contained in:
Sam Shleifer
2020-09-06 20:41:29 -04:00
committed by GitHub
parent f72fe1f31a
commit ce37be9d94

View File

@@ -3,6 +3,7 @@ import glob
import logging import logging
import os import os
import time import time
import warnings
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
@@ -10,6 +11,7 @@ from typing import Dict, List, Tuple
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from packaging import version
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train from lightning_base import BaseTransformer, add_generic_args, generic_train
@@ -354,7 +356,8 @@ def main(args, model=None) -> SummarizationModule:
model: SummarizationModule = SummarizationModule(args) model: SummarizationModule = SummarizationModule(args)
else: else:
model: SummarizationModule = TranslationModule(args) model: SummarizationModule = TranslationModule(args)
if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
warnings.warn("FP16 only seems to work with torch 1.5+apex")
dataset = Path(args.data_dir).name dataset = Path(args.data_dir).name
if ( if (
args.logger_name == "default" args.logger_name == "default"