also gathering file names in file_utils
This commit is contained in:
@@ -24,7 +24,7 @@ from .tokenization_roberta import RobertaTokenizer
|
|||||||
from .tokenization_distilbert import DistilBertTokenizer
|
from .tokenization_distilbert import DistilBertTokenizer
|
||||||
|
|
||||||
# Configurations
|
# Configurations
|
||||||
from .configuration_utils import CONFIG_NAME, PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .configuration_auto import AutoConfig
|
from .configuration_auto import AutoConfig
|
||||||
from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
@@ -36,7 +36,7 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH
|
|||||||
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
# Modeling
|
# Modeling
|
||||||
from .modeling_utils import (WEIGHTS_NAME, TF_WEIGHTS_NAME, PreTrainedModel, prune_layer, Conv1D)
|
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D)
|
||||||
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
|
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
|
||||||
AutoModelWithLMHead)
|
AutoModelWithLMHead)
|
||||||
|
|
||||||
@@ -70,4 +70,6 @@ from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, Wa
|
|||||||
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
|
||||||
|
|
||||||
# Files and general utilities
|
# Files and general utilities
|
||||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path, add_start_docstrings, add_end_docstrings)
|
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||||
|
cached_path, add_start_docstrings, add_end_docstrings,
|
||||||
|
WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME)
|
||||||
|
|||||||
@@ -24,12 +24,10 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path, CONFIG_NAME
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
CONFIG_NAME = "config.json"
|
|
||||||
|
|
||||||
class PretrainedConfig(object):
|
class PretrainedConfig(object):
|
||||||
r""" Base class for all configuration classes.
|
r""" Base class for all configuration classes.
|
||||||
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
|
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
|
||||||
|
|||||||
@@ -48,6 +48,10 @@ except (AttributeError, ImportError):
|
|||||||
|
|
||||||
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
|
||||||
|
|
||||||
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
|
CONFIG_NAME = "config.json"
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
if not six.PY2:
|
if not six.PY2:
|
||||||
|
|||||||
@@ -31,13 +31,10 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
|
||||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.nn import Identity
|
from torch.nn import Identity
|
||||||
|
|||||||
Reference in New Issue
Block a user