From 59fe641b8bbeca2f4aa8a9d68a2e83fb7945054e Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 5 Sep 2019 02:34:09 +0200 Subject: [PATCH] also gathering file names in file_utils --- pytorch_transformers/__init__.py | 8 +++++--- pytorch_transformers/configuration_utils.py | 4 +--- pytorch_transformers/file_utils.py | 4 ++++ pytorch_transformers/modeling_utils.py | 5 +---- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index b851c99c9d..24ff52e5d4 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -24,7 +24,7 @@ from .tokenization_roberta import RobertaTokenizer from .tokenization_distilbert import DistilBertTokenizer # Configurations -from .configuration_utils import CONFIG_NAME, PretrainedConfig +from .configuration_utils import PretrainedConfig from .configuration_auto import AutoConfig from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP @@ -36,7 +36,7 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP # Modeling -from .modeling_utils import (WEIGHTS_NAME, TF_WEIGHTS_NAME, PreTrainedModel, prune_layer, Conv1D) +from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D) from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelWithLMHead) @@ -70,4 +70,6 @@ from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, Wa WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) # Files and general utilities -from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path, add_start_docstrings, add_end_docstrings) +from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, + cached_path, add_start_docstrings, add_end_docstrings, + WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME) diff --git a/pytorch_transformers/configuration_utils.py b/pytorch_transformers/configuration_utils.py index 550b47fab8..7efc735d41 100644 --- a/pytorch_transformers/configuration_utils.py +++ b/pytorch_transformers/configuration_utils.py @@ -24,12 +24,10 @@ import logging import os from io import open -from .file_utils import cached_path +from .file_utils import cached_path, CONFIG_NAME logger = logging.getLogger(__name__) -CONFIG_NAME = "config.json" - class PretrainedConfig(object): r""" Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index 94ada93d91..a656e757b5 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -48,6 +48,10 @@ except (AttributeError, ImportError): PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility +WEIGHTS_NAME = "pytorch_model.bin" +TF_WEIGHTS_NAME = 'model.ckpt' +CONFIG_NAME = "config.json" + logger = logging.getLogger(__name__) # pylint: disable=invalid-name if not six.PY2: diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 48420f6d07..2fb4671674 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -31,13 +31,10 @@ from torch.nn import CrossEntropyLoss from torch.nn import functional as F from .configuration_utils import PretrainedConfig -from .file_utils import cached_path +from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME logger = logging.getLogger(__name__) -WEIGHTS_NAME = "pytorch_model.bin" -TF_WEIGHTS_NAME = 'model.ckpt' - try: from torch.nn import Identity