[s2s] warn if --fp16 for torch 1.6 (#6977)

This commit is contained in:
Sam Shleifer
2020-09-06 20:41:29 -04:00
committed by GitHub
parent f72fe1f31a
commit ce37be9d94

View File

@@ -3,6 +3,7 @@ import glob
import logging
import os
import time
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
@@ -10,6 +11,7 @@ from typing import Dict, List, Tuple
import numpy as np
import pytorch_lightning as pl
import torch
from packaging import version
from torch.utils.data import DataLoader
from lightning_base import BaseTransformer, add_generic_args, generic_train
@@ -354,7 +356,8 @@ def main(args, model=None) -> SummarizationModule:
model: SummarizationModule = SummarizationModule(args)
else:
model: SummarizationModule = TranslationModule(args)
if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
warnings.warn("FP16 only seems to work with torch 1.5+apex")
dataset = Path(args.data_dir).name
if (
args.logger_name == "default"