also gathering file names in file_utils

This commit is contained in:
thomwolf
2019-09-05 02:34:09 +02:00
parent d68a8fe462
commit 59fe641b8b
4 changed files with 11 additions and 10 deletions

View File

@@ -31,13 +31,10 @@ from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from .configuration_utils import PretrainedConfig
from .file_utils import cached_path
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
logger = logging.getLogger(__name__)
WEIGHTS_NAME = "pytorch_model.bin"
TF_WEIGHTS_NAME = 'model.ckpt'
try:
from torch.nn import Identity