Centralize logging (#6434)
* Logging * Style * hf_logging > utils.logging * Address @thomwolf's comments * Update test * Update src/transformers/benchmark/benchmark_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Revert bad change Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -17,8 +17,6 @@ else:
|
|||||||
absl.logging.set_stderrthreshold("info")
|
absl.logging.set_stderrthreshold("info")
|
||||||
absl.logging._warn_preinit_stderr = False
|
absl.logging._warn_preinit_stderr = False
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
# Configurations
|
# Configurations
|
||||||
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
|
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
|
||||||
@@ -184,9 +182,10 @@ from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
|||||||
from .trainer_utils import EvalPrediction, set_seed
|
from .trainer_utils import EvalPrediction, set_seed
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
from .training_args_tf import TFTrainingArguments
|
from .training_args_tf import TFTrainingArguments
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
if is_sklearn_available():
|
if is_sklearn_available():
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def swish(x):
|
def swish(x):
|
||||||
|
|||||||
@@ -18,13 +18,13 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import timeit
|
import timeit
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from ..configuration_utils import PretrainedConfig
|
from ..configuration_utils import PretrainedConfig
|
||||||
from ..file_utils import is_py3nvml_available, is_torch_available
|
from ..file_utils import is_py3nvml_available, is_torch_available
|
||||||
from ..modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
|
from ..modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
|
||||||
|
from ..utils import logging
|
||||||
from .benchmark_utils import (
|
from .benchmark_utils import (
|
||||||
Benchmark,
|
Benchmark,
|
||||||
Memory,
|
Memory,
|
||||||
@@ -45,7 +45,7 @@ if is_py3nvml_available():
|
|||||||
import py3nvml.py3nvml as nvml
|
import py3nvml.py3nvml as nvml
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PyTorchBenchmark(Benchmark):
|
class PyTorchBenchmark(Benchmark):
|
||||||
|
|||||||
@@ -14,11 +14,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
||||||
|
from ..utils import logging
|
||||||
from .benchmark_args_utils import BenchmarkArguments
|
from .benchmark_args_utils import BenchmarkArguments
|
||||||
|
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ if is_torch_tpu_available():
|
|||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -14,11 +14,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from ..file_utils import cached_property, is_tf_available, tf_required
|
from ..file_utils import cached_property, is_tf_available, tf_required
|
||||||
|
from ..utils import logging
|
||||||
from .benchmark_args_utils import BenchmarkArguments
|
from .benchmark_args_utils import BenchmarkArguments
|
||||||
|
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ if is_tf_available():
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -16,13 +16,14 @@
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from time import time
|
from time import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from ..utils import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def list_field(default=None, metadata=None):
|
def list_field(default=None, metadata=None):
|
||||||
|
|||||||
@@ -18,7 +18,6 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import random
|
import random
|
||||||
import timeit
|
import timeit
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
@@ -27,6 +26,7 @@ from typing import Callable, Optional
|
|||||||
from ..configuration_utils import PretrainedConfig
|
from ..configuration_utils import PretrainedConfig
|
||||||
from ..file_utils import is_py3nvml_available, is_tf_available
|
from ..file_utils import is_py3nvml_available, is_tf_available
|
||||||
from ..modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING
|
from ..modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING
|
||||||
|
from ..utils import logging
|
||||||
from .benchmark_utils import (
|
from .benchmark_utils import (
|
||||||
Benchmark,
|
Benchmark,
|
||||||
Memory,
|
Memory,
|
||||||
@@ -46,7 +46,7 @@ if is_tf_available():
|
|||||||
if is_py3nvml_available():
|
if is_py3nvml_available():
|
||||||
import py3nvml.py3nvml as nvml
|
import py3nvml.py3nvml as nvml
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
|
def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ Copyright by the AllenNLP authors.
|
|||||||
import copy
|
import copy
|
||||||
import csv
|
import csv
|
||||||
import linecache
|
import linecache
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import sys
|
import sys
|
||||||
@@ -22,6 +21,7 @@ from transformers import AutoConfig, PretrainedConfig
|
|||||||
from transformers import __version__ as version
|
from transformers import __version__ as version
|
||||||
|
|
||||||
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
|
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
|
||||||
|
from ..utils import logging
|
||||||
from .benchmark_args_utils import BenchmarkArguments
|
from .benchmark_args_utils import BenchmarkArguments
|
||||||
|
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ else:
|
|||||||
from signal import SIGKILL
|
from signal import SIGKILL
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
_is_memory_tracing_enabled = False
|
_is_memory_tracing_enabled = False
|
||||||
@@ -94,7 +94,7 @@ def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: b
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
if do_multi_processing:
|
if do_multi_processing:
|
||||||
logging.info("fFunction {func} is executed in its own process...")
|
logger.info(f"Function {func} is executed in its own process...")
|
||||||
return multi_process_func
|
return multi_process_func
|
||||||
else:
|
else:
|
||||||
return func
|
return func
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from logging import getLogger
|
|
||||||
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
from ..utils import logging
|
||||||
|
|
||||||
|
|
||||||
def convert_command_factory(args: Namespace):
|
def convert_command_factory(args: Namespace):
|
||||||
"""
|
"""
|
||||||
@@ -52,7 +53,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
finetuning_task_name: str,
|
finetuning_task_name: str,
|
||||||
*args
|
*args
|
||||||
):
|
):
|
||||||
self._logger = getLogger("transformers-cli/converting")
|
self._logger = logging.get_logger("transformers-cli/converting")
|
||||||
|
|
||||||
self._logger.info("Loading model {}".format(model_type))
|
self._logger.info("Loading model {}".format(model_type))
|
||||||
self._model_type = model_type
|
self._model_type = model_type
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import logging
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
||||||
|
|
||||||
|
from ..utils import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
def try_infer_format_from_ext(path: str):
|
def try_infer_format_from_ext(path: str):
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
@@ -6,6 +5,8 @@ from transformers import Pipeline
|
|||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||||
|
|
||||||
|
from ..utils import logging
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from fastapi import Body, FastAPI, HTTPException
|
from fastapi import Body, FastAPI, HTTPException
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from logging import getLogger
|
|
||||||
|
|
||||||
from transformers import SingleSentenceClassificationProcessor as Processor
|
from transformers import SingleSentenceClassificationProcessor as Processor
|
||||||
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
|
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
from ..utils import logging
|
||||||
|
|
||||||
|
|
||||||
if not is_tf_available() and not is_torch_available():
|
if not is_tf_available() and not is_torch_available():
|
||||||
raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
||||||
@@ -76,7 +77,7 @@ class TrainCommand(BaseTransformersCLICommand):
|
|||||||
train_parser.set_defaults(func=train_command_factory)
|
train_parser.set_defaults(func=train_command_factory)
|
||||||
|
|
||||||
def __init__(self, args: Namespace):
|
def __init__(self, args: Namespace):
|
||||||
self.logger = getLogger("transformers-cli/training")
|
self.logger = logging.get_logger("transformers-cli/training")
|
||||||
|
|
||||||
self.framework = "tf" if is_tf_available() else "torch"
|
self.framework = "tf" if is_tf_available() else "torch"
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
""" Auto Config class. """
|
""" Auto Config class. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||||
@@ -45,9 +44,6 @@ from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|||||||
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
|
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
|
||||||
(key, value)
|
(key, value)
|
||||||
for pretrained_map in [
|
for pretrained_map in [
|
||||||
|
|||||||
@@ -14,14 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" BART configuration """
|
""" BART configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .file_utils import add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings_to_callable
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"facebook/bart-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-base/config.json",
|
"facebook/bart-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-base/config.json",
|
||||||
|
|||||||
@@ -15,13 +15,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" BERT model configuration """
|
""" BERT model configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
||||||
|
|||||||
@@ -15,13 +15,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" CamemBERT configuration """
|
""" CamemBERT configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
|
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",
|
||||||
|
|||||||
@@ -14,13 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Salesforce CTRL configuration """
|
""" Salesforce CTRL configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-config.json"}
|
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-config.json"}
|
||||||
|
|
||||||
|
|||||||
@@ -14,13 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" DistilBERT model configuration """
|
""" DistilBERT model configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
|
"distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",
|
||||||
|
|||||||
@@ -14,13 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" DPR model configuration """
|
""" DPR model configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json",
|
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json",
|
||||||
|
|||||||
@@ -15,13 +15,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" ELECTRA model configuration """
|
""" ELECTRA model configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/config.json",
|
"google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/config.json",
|
||||||
|
|||||||
@@ -15,12 +15,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EncoderDecoderConfig(PretrainedConfig):
|
class EncoderDecoderConfig(PretrainedConfig):
|
||||||
|
|||||||
@@ -14,13 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Flaubert configuration, based on XLM. """
|
""" Flaubert configuration, based on XLM. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_xlm import XLMConfig
|
from .configuration_xlm import XLMConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"flaubert/flaubert_small_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/config.json",
|
"flaubert/flaubert_small_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/config.json",
|
||||||
|
|||||||
@@ -15,13 +15,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" OpenAI GPT-2 configuration """
|
""" OpenAI GPT-2 configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||||
|
|||||||
@@ -14,13 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Longformer configuration """
|
""" Longformer configuration """
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"allenai/longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",
|
"allenai/longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",
|
||||||
|
|||||||
@@ -14,12 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" MBART configuration """
|
""" MBART configuration """
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_bart import BartConfig
|
from .configuration_bart import BartConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
|
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
|
||||||
|
|||||||
@@ -15,11 +15,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" MMBT configuration """
|
""" MMBT configuration """
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MMBTConfig(object):
|
class MMBTConfig(object):
|
||||||
|
|||||||
@@ -12,12 +12,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" MobileBERT model configuration """
|
""" MobileBERT model configuration """
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"mobilebert-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/google/mobilebert-uncased/config.json"
|
"mobilebert-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/google/mobilebert-uncased/config.json"
|
||||||
|
|||||||
@@ -15,13 +15,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" OpenAI GPT configuration """
|
""" OpenAI GPT configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"
|
"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"
|
||||||
|
|||||||
@@ -14,13 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PEGASUS model configuration """
|
""" PEGASUS model configuration """
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_bart import BART_CONFIG_ARGS_DOC, BartConfig
|
from .configuration_bart import BART_CONFIG_ARGS_DOC, BartConfig
|
||||||
from .file_utils import add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings_to_callable
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
# These config values do not vary between checkpoints
|
# These config values do not vary between checkpoints
|
||||||
DEFAULTS = dict(
|
DEFAULTS = dict(
|
||||||
|
|||||||
@@ -15,13 +15,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Reformer model configuration """
|
""" Reformer model configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json",
|
"google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json",
|
||||||
|
|||||||
@@ -14,13 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" RetriBERT model configuration """
|
""" RetriBERT model configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
# TODO: uploadto AWS
|
# TODO: uploadto AWS
|
||||||
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -15,13 +15,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" RoBERTa configuration """
|
""" RoBERTa configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
|
"roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
|
||||||
|
|||||||
@@ -14,13 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" T5 model configuration """
|
""" T5 model configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
|
"t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
|
||||||
|
|||||||
@@ -16,13 +16,12 @@
|
|||||||
""" Transformer XL configuration """
|
""" Transformer XL configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",
|
||||||
|
|||||||
@@ -18,14 +18,14 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Tuple
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
|
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PretrainedConfig(object):
|
class PretrainedConfig(object):
|
||||||
|
|||||||
@@ -14,13 +14,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" XLM configuration """
|
""" XLM configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
|
"xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
|
||||||
|
|||||||
@@ -15,13 +15,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" XLM-RoBERTa configuration """
|
""" XLM-RoBERTa configuration """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
|
"xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",
|
||||||
|
|||||||
@@ -15,13 +15,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" XLNet configuration """
|
""" XLNet configuration """
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
|
"xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
|
||||||
|
|||||||
@@ -16,14 +16,15 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -33,6 +32,8 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.modeling_bart import _make_linear_from_emb
|
from transformers.modeling_bart import _make_linear_from_emb
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
||||||
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
|
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
|
||||||
@@ -40,8 +41,8 @@ if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
|||||||
raise Exception("requires fairseq >= 0.9.0")
|
raise Exception("requires fairseq >= 0.9.0")
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.set_verbosity_info()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
SAMPLE_TEXT = " Hello world! cécé herlolip"
|
SAMPLE_TEXT = " Hello world! cécé herlolip"
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ The script re-maps the TF2.x Bert weight names to the original names, so the mod
|
|||||||
You may adapt this script to include classification/MLM/NSP/etc. heads.
|
You may adapt this script to include classification/MLM/NSP/etc. heads.
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@@ -17,9 +16,11 @@ import torch
|
|||||||
|
|
||||||
from transformers import BertConfig, BertModel
|
from transformers import BertConfig, BertModel
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
logging.set_verbosity_info()
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
|
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):
|
||||||
|
|||||||
@@ -16,14 +16,15 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||||
|
|||||||
@@ -16,14 +16,15 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator):
|
||||||
|
|||||||
@@ -16,14 +16,15 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
|
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):
|
||||||
|
|||||||
@@ -16,14 +16,15 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -76,6 +75,8 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.file_utils import hf_bucket_url
|
from transformers.file_utils import hf_bucket_url
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -104,7 +105,7 @@ if is_torch_available():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
"bert": (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,),
|
"bert": (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,),
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -24,8 +23,10 @@ import torch
|
|||||||
|
|
||||||
from transformers import ReformerConfig, ReformerModelWithLMHead
|
from transformers import ReformerConfig, ReformerModelWithLMHead
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def set_param(torch_layer, weight, bias=None):
|
def set_param(torch_layer, weight, bias=None):
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
import fairseq
|
import fairseq
|
||||||
@@ -28,13 +27,15 @@ from packaging import version
|
|||||||
from transformers.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
|
from transformers.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
|
||||||
from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
|
from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||||
raise Exception("requires fairseq >= 0.9.0")
|
raise Exception("requires fairseq >= 0.9.0")
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.set_verbosity_info()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||||
|
|
||||||
|
|||||||
@@ -16,14 +16,15 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import T5Config, T5Model, load_tf_weights_in_t5
|
from transformers import T5Config, T5Model, load_tf_weights_in_t5
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import sys
|
import sys
|
||||||
@@ -33,8 +32,10 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
# We do this to be able to load python 2 datasets pickles
|
# We do this to be able to load python 2 datasets pickles
|
||||||
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
@@ -25,8 +24,10 @@ import torch
|
|||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME
|
from transformers import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
from transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
|
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -31,6 +30,8 @@ from transformers import (
|
|||||||
load_tf_weights_in_xlnet,
|
load_tf_weights_in_xlnet,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
GLUE_TASKS_NUM_LABELS = {
|
GLUE_TASKS_NUM_LABELS = {
|
||||||
"cola": 2,
|
"cola": 2,
|
||||||
@@ -45,7 +46,7 @@ GLUE_TASKS_NUM_LABELS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_xlnet_checkpoint_to_pytorch(
|
def convert_xlnet_checkpoint_to_pytorch(
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@@ -14,11 +13,12 @@ from ...tokenization_bart import BartTokenizer, BartTokenizerFast
|
|||||||
from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
from ...tokenization_xlm_roberta import XLMRobertaTokenizer
|
from ...tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||||
|
from ...utils import logging
|
||||||
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
|
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
|
||||||
from ..processors.utils import InputFeatures
|
from ..processors.utils import InputFeatures
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
@@ -9,9 +8,10 @@ from torch.utils.data.dataset import Dataset
|
|||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
|
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
class TextDataset(Dataset):
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@@ -12,10 +11,11 @@ from filelock import FileLock
|
|||||||
|
|
||||||
from ...modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
from ...modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
|
from ...utils import logging
|
||||||
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
|
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
|
||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
|
|||||||
@@ -10,15 +10,16 @@ that a question is unanswerable.
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
import string
|
import string
|
||||||
|
|
||||||
from transformers.tokenization_bert import BasicTokenizer
|
from transformers.tokenization_bert import BasicTokenizer
|
||||||
|
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def normalize_answer(s):
|
def normalize_answer(s):
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" GLUE processors and helpers """
|
""" GLUE processors and helpers """
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -23,13 +22,14 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
from ...file_utils import is_tf_available
|
from ...file_utils import is_tf_available
|
||||||
from ...tokenization_utils import PreTrainedTokenizer
|
from ...tokenization_utils import PreTrainedTokenizer
|
||||||
|
from ...utils import logging
|
||||||
from .utils import DataProcessor, InputExample, InputFeatures
|
from .utils import DataProcessor, InputExample, InputFeatures
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def glue_convert_examples_to_features(
|
def glue_convert_examples_to_features(
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from multiprocessing import Pool, cpu_count
|
from multiprocessing import Pool, cpu_count
|
||||||
@@ -10,6 +9,7 @@ from tqdm import tqdm
|
|||||||
from ...file_utils import is_tf_available, is_torch_available
|
from ...file_utils import is_tf_available, is_torch_available
|
||||||
from ...tokenization_bert import whitespace_tokenize
|
from ...tokenization_bert import whitespace_tokenize
|
||||||
from ...tokenization_utils_base import TruncationStrategy
|
from ...tokenization_utils_base import TruncationStrategy
|
||||||
|
from ...utils import logging
|
||||||
from .utils import DataProcessor
|
from .utils import DataProcessor
|
||||||
|
|
||||||
|
|
||||||
@@ -24,7 +24,7 @@ if is_torch_available():
|
|||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
|
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
|
||||||
|
|||||||
@@ -17,14 +17,14 @@
|
|||||||
import csv
|
import csv
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from ...file_utils import is_tf_available, is_torch_available
|
from ...file_utils import is_tf_available, is_torch_available
|
||||||
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -16,13 +16,13 @@
|
|||||||
""" XNLI utils (dataset loading and evaluation) """
|
""" XNLI utils (dataset loading and evaluation) """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from ...utils import logging
|
||||||
from .utils import DataProcessor, InputExample
|
from .utils import DataProcessor, InputExample
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class XnliProcessor(DataProcessor):
|
class XnliProcessor(DataProcessor):
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ Copyright by the AllenNLP authors.
|
|||||||
|
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
@@ -30,9 +29,10 @@ import requests
|
|||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||||
@@ -757,7 +757,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
|
|||||||
total=total,
|
total=total,
|
||||||
initial=resume_size,
|
initial=resume_size,
|
||||||
desc="Downloading",
|
desc="Downloading",
|
||||||
disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
|
disable=bool(logging.get_verbosity() > logging.NOTSET),
|
||||||
)
|
)
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
for chunk in response.iter_content(chunk_size=1024):
|
||||||
if chunk: # filter out keep-alive new chunks
|
if chunk: # filter out keep-alive new chunks
|
||||||
|
|||||||
@@ -14,13 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TFGenerationMixin:
|
class TFGenerationMixin:
|
||||||
|
|||||||
@@ -14,15 +14,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GenerationMixin:
|
class GenerationMixin:
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
@@ -30,9 +29,10 @@ from .file_utils import (
|
|||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ModelCard:
|
class ModelCard:
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch ALBERT model. """
|
"""PyTorch ALBERT model. """
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@@ -44,9 +43,10 @@ from .modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices
|
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "AlbertConfig"
|
_CONFIG_FOR_DOC = "AlbertConfig"
|
||||||
_TOKENIZER_FOR_DOC = "AlbertTokenizer"
|
_TOKENIZER_FOR_DOC = "AlbertTokenizer"
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
""" Auto Model class. """
|
""" Auto Model class. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
@@ -172,9 +171,10 @@ from .modeling_xlnet import (
|
|||||||
XLNetLMHeadModel,
|
XLNetLMHeadModel,
|
||||||
XLNetModel,
|
XLNetModel,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MODEL_MAPPING = OrderedDict(
|
MODEL_MAPPING = OrderedDict(
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch BART model, ported from the fairseq repo."""
|
"""PyTorch BART model, ported from the fairseq repo."""
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
@@ -43,9 +42,10 @@ from .modeling_outputs import (
|
|||||||
Seq2SeqSequenceClassifierOutput,
|
Seq2SeqSequenceClassifierOutput,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "BartConfig"
|
_CONFIG_FOR_DOC = "BartConfig"
|
||||||
_TOKENIZER_FOR_DOC = "BartTokenizer"
|
_TOKENIZER_FOR_DOC = "BartTokenizer"
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
"""PyTorch BERT model. """
|
"""PyTorch BERT model. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@@ -54,9 +53,10 @@ from .modeling_utils import (
|
|||||||
find_pruneable_heads_and_indices,
|
find_pruneable_heads_and_indices,
|
||||||
prune_linear_layer,
|
prune_linear_layer,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "BertConfig"
|
_CONFIG_FOR_DOC = "BertConfig"
|
||||||
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||||
|
|||||||
@@ -15,8 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch CamemBERT model. """
|
"""PyTorch CamemBERT model. """
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_camembert import CamembertConfig
|
from .configuration_camembert import CamembertConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_roberta import (
|
from .modeling_roberta import (
|
||||||
@@ -28,9 +26,10 @@ from .modeling_roberta import (
|
|||||||
RobertaForTokenClassification,
|
RobertaForTokenClassification,
|
||||||
RobertaModel,
|
RobertaModel,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_TOKENIZER_FOR_DOC = "CamembertTokenizer"
|
_TOKENIZER_FOR_DOC = "CamembertTokenizer"
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
""" PyTorch CTRL model."""
|
""" PyTorch CTRL model."""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -28,9 +27,10 @@ from .configuration_ctrl import CTRLConfig
|
|||||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||||
from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "CTRLConfig"
|
_CONFIG_FOR_DOC = "CTRLConfig"
|
||||||
_TOKENIZER_FOR_DOC = "CTRLTokenizer"
|
_TOKENIZER_FOR_DOC = "CTRLTokenizer"
|
||||||
|
|||||||
@@ -19,7 +19,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@@ -50,9 +49,10 @@ from .modeling_utils import (
|
|||||||
find_pruneable_heads_and_indices,
|
find_pruneable_heads_and_indices,
|
||||||
prune_linear_layer,
|
prune_linear_layer,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "DistilBertConfig"
|
_CONFIG_FOR_DOC = "DistilBertConfig"
|
||||||
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"
|
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
""" PyTorch DPR model for Open Domain Question Answering."""
|
""" PyTorch DPR model for Open Domain Question Answering."""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -27,9 +26,10 @@ from .file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_
|
|||||||
from .modeling_bert import BertModel
|
from .modeling_bert import BertModel
|
||||||
from .modeling_outputs import BaseModelOutputWithPooling
|
from .modeling_outputs import BaseModelOutputWithPooling
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "DPRConfig"
|
_CONFIG_FOR_DOC = "DPRConfig"
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -27,9 +26,10 @@ from .modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from .modeling_utils import SequenceSummary
|
from .modeling_utils import SequenceSummary
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "ElectraConfig"
|
_CONFIG_FOR_DOC = "ElectraConfig"
|
||||||
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
|
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
|
||||||
|
|||||||
@@ -15,15 +15,15 @@
|
|||||||
""" Classes to support Encoder-Decoder architectures """
|
""" Classes to support Encoder-Decoder architectures """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .configuration_encoder_decoder import EncoderDecoderConfig
|
from .configuration_encoder_decoder import EncoderDecoderConfig
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EncoderDecoderModel(PreTrainedModel):
|
class EncoderDecoderModel(PreTrainedModel):
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
""" PyTorch Flaubert model, based on XLM. """
|
""" PyTorch Flaubert model, based on XLM. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -34,9 +33,10 @@ from .modeling_xlm import (
|
|||||||
XLMWithLMHeadModel,
|
XLMWithLMHeadModel,
|
||||||
get_masks,
|
get_masks,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "FlaubertConfig"
|
_CONFIG_FOR_DOC = "FlaubertConfig"
|
||||||
_TOKENIZER_FOR_DOC = "FlaubertTokenizer"
|
_TOKENIZER_FOR_DOC = "FlaubertTokenizer"
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
"""PyTorch OpenAI GPT-2 model."""
|
"""PyTorch OpenAI GPT-2 model."""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -43,9 +42,10 @@ from .modeling_utils import (
|
|||||||
find_pruneable_heads_and_indices,
|
find_pruneable_heads_and_indices,
|
||||||
prune_conv1d_layer,
|
prune_conv1d_layer,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "GPT2Config"
|
_CONFIG_FOR_DOC = "GPT2Config"
|
||||||
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch Longformer model. """
|
"""PyTorch Longformer model. """
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@@ -47,9 +46,10 @@ from .modeling_utils import (
|
|||||||
find_pruneable_heads_and_indices,
|
find_pruneable_heads_and_indices,
|
||||||
prune_linear_layer,
|
prune_linear_layer,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "LongformerConfig"
|
_CONFIG_FOR_DOC = "LongformerConfig"
|
||||||
_TOKENIZER_FOR_DOC = "LongformerTokenizer"
|
_TOKENIZER_FOR_DOC = "LongformerTokenizer"
|
||||||
|
|||||||
@@ -16,8 +16,6 @@
|
|||||||
"""PyTorch MMBT model. """
|
"""PyTorch MMBT model. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
@@ -25,9 +23,10 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
|||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
|
||||||
from .modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
|
from .modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
|
||||||
from .modeling_utils import ModuleUtilsMixin
|
from .modeling_utils import ModuleUtilsMixin
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "MMBTConfig"
|
_CONFIG_FOR_DOC = "MMBTConfig"
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@@ -53,9 +52,10 @@ from .modeling_outputs import (
|
|||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "MobileBertConfig"
|
_CONFIG_FOR_DOC = "MobileBertConfig"
|
||||||
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"
|
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"
|
||||||
|
|||||||
@@ -17,7 +17,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@@ -45,9 +44,10 @@ from .modeling_utils import (
|
|||||||
find_pruneable_heads_and_indices,
|
find_pruneable_heads_and_indices,
|
||||||
prune_conv1d_layer,
|
prune_conv1d_layer,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "OpenAIGPTConfig"
|
_CONFIG_FOR_DOC = "OpenAIGPTConfig"
|
||||||
_TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer"
|
_TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer"
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch REFORMER model. """
|
"""PyTorch REFORMER model. """
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
import sys
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -41,9 +40,10 @@ from .file_utils import (
|
|||||||
)
|
)
|
||||||
from .modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput
|
from .modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput
|
||||||
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward
|
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "ReformerConfig"
|
_CONFIG_FOR_DOC = "ReformerConfig"
|
||||||
_TOKENIZER_FOR_DOC = "ReformerTokenizer"
|
_TOKENIZER_FOR_DOC = "ReformerTokenizer"
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ RetriBERT model
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -28,9 +27,10 @@ from .configuration_retribert import RetriBertConfig
|
|||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_bert import BertLayerNorm, BertModel
|
from .modeling_bert import BertLayerNorm, BertModel
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"yjernite/retribert-base-uncased",
|
"yjernite/retribert-base-uncased",
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
"""PyTorch RoBERTa model. """
|
"""PyTorch RoBERTa model. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -39,9 +38,10 @@ from .modeling_outputs import (
|
|||||||
SequenceClassifierOutput,
|
SequenceClassifierOutput,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "RobertaConfig"
|
_CONFIG_FOR_DOC = "RobertaConfig"
|
||||||
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@@ -36,9 +35,10 @@ from .file_utils import (
|
|||||||
)
|
)
|
||||||
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
|
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
|
||||||
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "T5Config"
|
_CONFIG_FOR_DOC = "T5Config"
|
||||||
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
""" TF 2.0 ALBERT model. """
|
""" TF 2.0 ALBERT model. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
@@ -53,9 +52,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "AlbertConfig"
|
_CONFIG_FOR_DOC = "AlbertConfig"
|
||||||
_TOKENIZER_FOR_DOC = "AlbertTokenizer"
|
_TOKENIZER_FOR_DOC = "AlbertTokenizer"
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
""" Auto Model class. """
|
""" Auto Model class. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
@@ -139,9 +138,10 @@ from .modeling_tf_xlnet import (
|
|||||||
TFXLNetLMHeadModel,
|
TFXLNetLMHeadModel,
|
||||||
TFXLNetModel,
|
TFXLNetModel,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
TF_MODEL_MAPPING = OrderedDict(
|
TF_MODEL_MAPPING = OrderedDict(
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
""" TF 2.0 BERT model. """
|
""" TF 2.0 BERT model. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
@@ -56,9 +55,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "BertConfig"
|
_CONFIG_FOR_DOC = "BertConfig"
|
||||||
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||||
|
|||||||
@@ -15,9 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" TF 2.0 CamemBERT model. """
|
""" TF 2.0 CamemBERT model. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from .configuration_camembert import CamembertConfig
|
from .configuration_camembert import CamembertConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_roberta import (
|
from .modeling_tf_roberta import (
|
||||||
@@ -28,9 +25,10 @@ from .modeling_tf_roberta import (
|
|||||||
TFRobertaForTokenClassification,
|
TFRobertaForTokenClassification,
|
||||||
TFRobertaModel,
|
TFRobertaModel,
|
||||||
)
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
# See all CamemBERT models at https://huggingface.co/models?filter=camembert
|
# See all CamemBERT models at https://huggingface.co/models?filter=camembert
|
||||||
|
|||||||
@@ -16,8 +16,6 @@
|
|||||||
""" TF 2.0 CTRL model."""
|
""" TF 2.0 CTRL model."""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@@ -32,9 +30,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "CTRLConfig"
|
_CONFIG_FOR_DOC = "CTRLConfig"
|
||||||
_TOKENIZER_FOR_DOC = "CTRLTokenizer"
|
_TOKENIZER_FOR_DOC = "CTRLTokenizer"
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -50,9 +49,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "DistilBertConfig"
|
_CONFIG_FOR_DOC = "DistilBertConfig"
|
||||||
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"
|
_TOKENIZER_FOR_DOC = "DistilBertTokenizer"
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
@@ -35,9 +34,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "ElectraConfig"
|
_CONFIG_FOR_DOC = "ElectraConfig"
|
||||||
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
|
_TOKENIZER_FOR_DOC = "ElectraTokenizer"
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
""" TF 2.0 Flaubert model.
|
""" TF 2.0 Flaubert model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -36,9 +35,10 @@ from .modeling_tf_xlm import (
|
|||||||
get_masks,
|
get_masks,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
# See all Flaubert models at https://huggingface.co/models?filter=flaubert
|
# See all Flaubert models at https://huggingface.co/models?filter=flaubert
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
""" TF 2.0 OpenAI GPT-2 model. """
|
""" TF 2.0 OpenAI GPT-2 model. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
@@ -43,9 +42,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "GPT2Config"
|
_CONFIG_FOR_DOC = "GPT2Config"
|
||||||
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
||||||
|
|||||||
@@ -14,8 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tensorflow Longformer model. """
|
"""Tensorflow Longformer model. """
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_longformer import LongformerConfig
|
from .configuration_longformer import LongformerConfig
|
||||||
@@ -37,9 +35,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "LongformerConfig"
|
_CONFIG_FOR_DOC = "LongformerConfig"
|
||||||
_TOKENIZER_FOR_DOC = "LongformerTokenizer"
|
_TOKENIZER_FOR_DOC = "LongformerTokenizer"
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
""" TF 2.0 MobileBERT model. """
|
""" TF 2.0 MobileBERT model. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
@@ -54,9 +53,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "MobileBertConfig"
|
_CONFIG_FOR_DOC = "MobileBertConfig"
|
||||||
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"
|
_TOKENIZER_FOR_DOC = "MobileBertTokenizer"
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
""" TF 2.0 OpenAI GPT model."""
|
""" TF 2.0 OpenAI GPT model."""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
@@ -43,9 +42,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "OpenAIGPTConfig"
|
_CONFIG_FOR_DOC = "OpenAIGPTConfig"
|
||||||
_TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer"
|
_TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer"
|
||||||
|
|||||||
@@ -16,14 +16,15 @@
|
|||||||
""" PyTorch - TF 2.0 general utilities."""
|
""" PyTorch - TF 2.0 general utilities."""
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=""):
|
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=""):
|
||||||
|
|||||||
@@ -16,8 +16,6 @@
|
|||||||
""" TF 2.0 RoBERTa model. """
|
""" TF 2.0 RoBERTa model. """
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
@@ -48,9 +46,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils_base import BatchEncoding
|
from .tokenization_utils_base import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "RobertaConfig"
|
_CONFIG_FOR_DOC = "RobertaConfig"
|
||||||
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
_TOKENIZER_FOR_DOC = "RobertaTokenizer"
|
||||||
|
|||||||
@@ -18,7 +18,6 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@@ -42,9 +41,10 @@ from .modeling_tf_utils import (
|
|||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "T5Config"
|
_CONFIG_FOR_DOC = "T5Config"
|
||||||
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user