From ce37be9d94da57897cce9c49b3421e6a8a927d4a Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 6 Sep 2020 20:41:29 -0400 Subject: [PATCH] [s2s] warn if --fp16 for torch 1.6 (#6977) --- examples/seq2seq/finetune.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 73b69d02b3..ef0445e900 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -3,6 +3,7 @@ import glob import logging import os import time +import warnings from collections import defaultdict from pathlib import Path from typing import Dict, List, Tuple @@ -10,6 +11,7 @@ from typing import Dict, List, Tuple import numpy as np import pytorch_lightning as pl import torch +from packaging import version from torch.utils.data import DataLoader from lightning_base import BaseTransformer, add_generic_args, generic_train @@ -354,7 +356,8 @@ def main(args, model=None) -> SummarizationModule: model: SummarizationModule = SummarizationModule(args) else: model: SummarizationModule = TranslationModule(args) - + if version.parse(torch.__version__) == version.parse("1.6") and args.fp16: + warnings.warn("FP16 only seems to work with torch 1.5+apex") dataset = Path(args.data_dir).name if ( args.logger_name == "default"