[s2s] warn if --fp16 for torch 1.6 (#6977)
This commit is contained in:
@@ -3,6 +3,7 @@ import glob
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
@@ -10,6 +11,7 @@ from typing import Dict, List, Tuple
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
@@ -354,7 +356,8 @@ def main(args, model=None) -> SummarizationModule:
|
||||
model: SummarizationModule = SummarizationModule(args)
|
||||
else:
|
||||
model: SummarizationModule = TranslationModule(args)
|
||||
|
||||
if version.parse(torch.__version__) == version.parse("1.6") and args.fp16:
|
||||
warnings.warn("FP16 only seems to work with torch 1.5+apex")
|
||||
dataset = Path(args.data_dir).name
|
||||
if (
|
||||
args.logger_name == "default"
|
||||
|
||||
Reference in New Issue
Block a user