Centralize logging (#6434)

* Logging

* Style

* hf_logging > utils.logging

* Address @thomwolf's comments

* Update test

* Update src/transformers/benchmark/benchmark_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Revert bad change

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Lysandre Debut
2020-08-26 11:10:36 -04:00
committed by GitHub
parent 461ae86812
commit 77abd1e79f
144 changed files with 497 additions and 347 deletions

View File

@@ -17,8 +17,6 @@ else:
absl.logging.set_stderrthreshold("info") absl.logging.set_stderrthreshold("info")
absl.logging._warn_preinit_stderr = False absl.logging._warn_preinit_stderr = False
import logging
# Configurations # Configurations
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
@@ -184,9 +182,10 @@ from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
from .trainer_utils import EvalPrediction, set_seed from .trainer_utils import EvalPrediction, set_seed
from .training_args import TrainingArguments from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments from .training_args_tf import TFTrainingArguments
from .utils import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_sklearn_available(): if is_sklearn_available():

View File

@@ -1,11 +1,12 @@
import logging
import math import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from .utils import logging
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
def swish(x): def swish(x):

View File

@@ -18,13 +18,13 @@
""" """
import logging
import timeit import timeit
from typing import Callable, Optional from typing import Callable, Optional
from ..configuration_utils import PretrainedConfig from ..configuration_utils import PretrainedConfig
from ..file_utils import is_py3nvml_available, is_torch_available from ..file_utils import is_py3nvml_available, is_torch_available
from ..modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING from ..modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
from ..utils import logging
from .benchmark_utils import ( from .benchmark_utils import (
Benchmark, Benchmark,
Memory, Memory,
@@ -45,7 +45,7 @@ if is_py3nvml_available():
import py3nvml.py3nvml as nvml import py3nvml.py3nvml as nvml
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
class PyTorchBenchmark(Benchmark): class PyTorchBenchmark(Benchmark):

View File

@@ -14,11 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple from typing import Tuple
from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required from ..file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments from .benchmark_args_utils import BenchmarkArguments
@@ -29,7 +29,7 @@ if is_torch_tpu_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
@dataclass @dataclass

View File

@@ -14,11 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple from typing import Tuple
from ..file_utils import cached_property, is_tf_available, tf_required from ..file_utils import cached_property, is_tf_available, tf_required
from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments from .benchmark_args_utils import BenchmarkArguments
@@ -26,7 +26,7 @@ if is_tf_available():
import tensorflow as tf import tensorflow as tf
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
@dataclass @dataclass

View File

@@ -16,13 +16,14 @@
import dataclasses import dataclasses
import json import json
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from time import time from time import time
from typing import List from typing import List
from ..utils import logging
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
def list_field(default=None, metadata=None): def list_field(default=None, metadata=None):

View File

@@ -18,7 +18,6 @@
""" """
import logging
import random import random
import timeit import timeit
from functools import wraps from functools import wraps
@@ -27,6 +26,7 @@ from typing import Callable, Optional
from ..configuration_utils import PretrainedConfig from ..configuration_utils import PretrainedConfig
from ..file_utils import is_py3nvml_available, is_tf_available from ..file_utils import is_py3nvml_available, is_tf_available
from ..modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING from ..modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING
from ..utils import logging
from .benchmark_utils import ( from .benchmark_utils import (
Benchmark, Benchmark,
Memory, Memory,
@@ -46,7 +46,7 @@ if is_tf_available():
if is_py3nvml_available(): if is_py3nvml_available():
import py3nvml.py3nvml as nvml import py3nvml.py3nvml as nvml
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool): def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):

View File

@@ -7,7 +7,6 @@ Copyright by the AllenNLP authors.
import copy import copy
import csv import csv
import linecache import linecache
import logging
import os import os
import platform import platform
import sys import sys
@@ -22,6 +21,7 @@ from transformers import AutoConfig, PretrainedConfig
from transformers import __version__ as version from transformers import __version__ as version
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments from .benchmark_args_utils import BenchmarkArguments
@@ -43,7 +43,7 @@ else:
from signal import SIGKILL from signal import SIGKILL
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_is_memory_tracing_enabled = False _is_memory_tracing_enabled = False
@@ -94,7 +94,7 @@ def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: b
return result return result
if do_multi_processing: if do_multi_processing:
logging.info("fFunction {func} is executed in its own process...") logger.info(f"Function {func} is executed in its own process...")
return multi_process_func return multi_process_func
else: else:
return func return func

View File

@@ -1,8 +1,9 @@
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from logging import getLogger
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
def convert_command_factory(args: Namespace): def convert_command_factory(args: Namespace):
""" """
@@ -52,7 +53,7 @@ class ConvertCommand(BaseTransformersCLICommand):
finetuning_task_name: str, finetuning_task_name: str,
*args *args
): ):
self._logger = getLogger("transformers-cli/converting") self._logger = logging.get_logger("transformers-cli/converting")
self._logger.info("Loading model {}".format(model_type)) self._logger.info("Loading model {}".format(model_type))
self._model_type = model_type self._model_type = model_type

View File

@@ -1,11 +1,12 @@
import logging
from argparse import ArgumentParser from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
from ..utils import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def try_infer_format_from_ext(path: str): def try_infer_format_from_ext(path: str):

View File

@@ -1,4 +1,3 @@
import logging
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional from typing import Any, List, Optional
@@ -6,6 +5,8 @@ from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline from transformers.pipelines import SUPPORTED_TASKS, pipeline
from ..utils import logging
try: try:
from fastapi import Body, FastAPI, HTTPException from fastapi import Body, FastAPI, HTTPException

View File

@@ -1,11 +1,12 @@
import os import os
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from logging import getLogger
from transformers import SingleSentenceClassificationProcessor as Processor from transformers import SingleSentenceClassificationProcessor as Processor
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
if not is_tf_available() and not is_torch_available(): if not is_tf_available() and not is_torch_available():
raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training") raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
@@ -76,7 +77,7 @@ class TrainCommand(BaseTransformersCLICommand):
train_parser.set_defaults(func=train_command_factory) train_parser.set_defaults(func=train_command_factory)
def __init__(self, args: Namespace): def __init__(self, args: Namespace):
self.logger = getLogger("transformers-cli/training") self.logger = logging.get_logger("transformers-cli/training")
self.framework = "tf" if is_tf_available() else "torch" self.framework = "tf" if is_tf_available() else "torch"

View File

@@ -15,7 +15,6 @@
""" Auto Config class. """ """ Auto Config class. """
import logging
from collections import OrderedDict from collections import OrderedDict
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
@@ -45,9 +44,6 @@ from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
logger = logging.getLogger(__name__)
ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict( ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
(key, value) (key, value)
for pretrained_map in [ for pretrained_map in [

View File

@@ -14,14 +14,12 @@
# limitations under the License. # limitations under the License.
""" BART configuration """ """ BART configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings_to_callable from .file_utils import add_start_docstrings_to_callable
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/bart-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-base/config.json", "facebook/bart-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-base/config.json",

View File

@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" BERT model configuration """ """ BERT model configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",

View File

@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" CamemBERT configuration """ """ CamemBERT configuration """
import logging
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json", "camembert-base": "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json",

View File

@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
""" Salesforce CTRL configuration """ """ Salesforce CTRL configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-config.json"} CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-config.json"}

View File

@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
""" DistilBERT model configuration """ """ DistilBERT model configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", "distilbert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json",

View File

@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
""" DPR model configuration """ """ DPR model configuration """
import logging
from .configuration_bert import BertConfig from .configuration_bert import BertConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = { DPR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json", "facebook/dpr-ctx_encoder-single-nq-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/dpr-ctx_encoder-single-nq-base/config.json",

View File

@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" ELECTRA model configuration """ """ ELECTRA model configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP = { ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/config.json", "google/electra-small-generator": "https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-small-generator/config.json",

View File

@@ -15,12 +15,12 @@
# limitations under the License. # limitations under the License.
import copy import copy
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
class EncoderDecoderConfig(PretrainedConfig): class EncoderDecoderConfig(PretrainedConfig):

View File

@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
""" Flaubert configuration, based on XLM. """ """ Flaubert configuration, based on XLM. """
import logging
from .configuration_xlm import XLMConfig from .configuration_xlm import XLMConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"flaubert/flaubert_small_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/config.json", "flaubert/flaubert_small_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/config.json",

View File

@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" OpenAI GPT-2 configuration """ """ OpenAI GPT-2 configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = { GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", "gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",

View File

@@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
""" Longformer configuration """ """ Longformer configuration """
import logging
from typing import List, Union from typing import List, Union
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"allenai/longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json", "allenai/longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",

View File

@@ -14,12 +14,11 @@
# limitations under the License. # limitations under the License.
""" MBART configuration """ """ MBART configuration """
import logging
from .configuration_bart import BartConfig from .configuration_bart import BartConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = { MBART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json", "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",

View File

@@ -15,11 +15,10 @@
# limitations under the License. # limitations under the License.
""" MMBT configuration """ """ MMBT configuration """
from .utils import logging
import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
class MMBTConfig(object): class MMBTConfig(object):

View File

@@ -12,12 +12,11 @@
# limitations under the License. # limitations under the License.
""" MobileBERT model configuration """ """ MobileBERT model configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"mobilebert-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/google/mobilebert-uncased/config.json" "mobilebert-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/google/mobilebert-uncased/config.json"

View File

@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" OpenAI GPT configuration """ """ OpenAI GPT configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json" "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"

View File

@@ -14,13 +14,12 @@
# limitations under the License. # limitations under the License.
""" PEGASUS model configuration """ """ PEGASUS model configuration """
import logging
from .configuration_bart import BART_CONFIG_ARGS_DOC, BartConfig from .configuration_bart import BART_CONFIG_ARGS_DOC, BartConfig
from .file_utils import add_start_docstrings_to_callable from .file_utils import add_start_docstrings_to_callable
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
# These config values do not vary between checkpoints # These config values do not vary between checkpoints
DEFAULTS = dict( DEFAULTS = dict(

View File

@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" Reformer model configuration """ """ Reformer model configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json", "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/config.json",

View File

@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
""" RetriBERT model configuration """ """ RetriBERT model configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
# TODO: uploadto AWS # TODO: uploadto AWS
RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {

View File

@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" RoBERTa configuration """ """ RoBERTa configuration """
import logging
from .configuration_bert import BertConfig from .configuration_bert import BertConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", "roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",

View File

@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
""" T5 model configuration """ """ T5 model configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = { T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json", "t5-small": "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",

View File

@@ -16,13 +16,12 @@
""" Transformer XL configuration """ """ Transformer XL configuration """
import logging
import warnings import warnings
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = { TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json", "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json",

View File

@@ -18,14 +18,14 @@
import copy import copy
import json import json
import logging
import os import os
from typing import Any, Dict, Tuple from typing import Any, Dict, Tuple
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
class PretrainedConfig(object): class PretrainedConfig(object):

View File

@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
""" XLM configuration """ """ XLM configuration """
import logging
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json", "xlm-mlm-en-2048": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",

View File

@@ -15,13 +15,11 @@
# limitations under the License. # limitations under the License.
""" XLM-RoBERTa configuration """ """ XLM-RoBERTa configuration """
import logging
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json", "xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-config.json",

View File

@@ -15,13 +15,13 @@
# limitations under the License. # limitations under the License.
""" XLNet configuration """ """ XLNet configuration """
import logging
import warnings import warnings
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json", "xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",

View File

@@ -16,14 +16,15 @@
import argparse import argparse
import logging
import torch import torch
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
from .utils import logging
logging.basicConfig(level=logging.INFO)
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path): def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, albert_config_file, pytorch_dump_path):

View File

@@ -16,7 +16,6 @@
import argparse import argparse
import logging
import os import os
from pathlib import Path from pathlib import Path
@@ -33,6 +32,8 @@ from transformers import (
) )
from transformers.modeling_bart import _make_linear_from_emb from transformers.modeling_bart import _make_linear_from_emb
from .utils import logging
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"] FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification} extra_arch = {"bart.large": BartModel, "bart.large.mnli": BartForSequenceClassification}
@@ -40,8 +41,8 @@ if version.parse(fairseq.__version__) < version.parse("0.9.0"):
raise Exception("requires fairseq >= 0.9.0") raise Exception("requires fairseq >= 0.9.0")
logging.basicConfig(level=logging.INFO) logging.set_verbosity_info()
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
SAMPLE_TEXT = " Hello world! cécé herlolip" SAMPLE_TEXT = " Hello world! cécé herlolip"

View File

@@ -8,7 +8,6 @@ The script re-maps the TF2.x Bert weight names to the original names, so the mod
You may adapt this script to include classification/MLM/NSP/etc. heads. You may adapt this script to include classification/MLM/NSP/etc. heads.
""" """
import argparse import argparse
import logging
import os import os
import re import re
@@ -17,9 +16,11 @@ import torch
from transformers import BertConfig, BertModel from transformers import BertConfig, BertModel
from .utils import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def load_tf2_weights_in_bert(model, tf_checkpoint_path, config): def load_tf2_weights_in_bert(model, tf_checkpoint_path, config):

View File

@@ -16,14 +16,15 @@
import argparse import argparse
import logging
import torch import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from .utils import logging
logging.basicConfig(level=logging.INFO)
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):

View File

@@ -16,14 +16,15 @@
import argparse import argparse
import logging
import torch import torch
from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
from .utils import logging
logging.basicConfig(level=logging.INFO)
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator): def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator):

View File

@@ -16,14 +16,15 @@
import argparse import argparse
import logging
import torch import torch
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2 from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
from .utils import logging
logging.basicConfig(level=logging.INFO)
logging.set_verbosity_info()
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):

View File

@@ -1,12 +1,13 @@
import argparse import argparse
import logging
import torch import torch
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
from .utils import logging
logging.basicConfig(level=logging.INFO)
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path): def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, mobilebert_config_file, pytorch_dump_path):

View File

@@ -16,14 +16,15 @@
import argparse import argparse
import logging
import torch import torch
from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
from .utils import logging
logging.basicConfig(level=logging.INFO)
logging.set_verbosity_info()
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):

View File

@@ -16,7 +16,6 @@
import argparse import argparse
import logging
import os import os
from transformers import ( from transformers import (
@@ -76,6 +75,8 @@ from transformers import (
) )
from transformers.file_utils import hf_bucket_url from transformers.file_utils import hf_bucket_url
from .utils import logging
if is_torch_available(): if is_torch_available():
import numpy as np import numpy as np
@@ -104,7 +105,7 @@ if is_torch_available():
) )
logging.basicConfig(level=logging.INFO) logging.set_verbosity_info()
MODEL_CLASSES = { MODEL_CLASSES = {
"bert": (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,), "bert": (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,),

View File

@@ -16,7 +16,6 @@
import argparse import argparse
import logging
import pickle import pickle
import numpy as np import numpy as np
@@ -24,8 +23,10 @@ import torch
from transformers import ReformerConfig, ReformerModelWithLMHead from transformers import ReformerConfig, ReformerModelWithLMHead
from .utils import logging
logging.basicConfig(level=logging.INFO)
logging.set_verbosity_info()
def set_param(torch_layer, weight, bias=None): def set_param(torch_layer, weight, bias=None):

View File

@@ -16,7 +16,6 @@
import argparse import argparse
import logging
import pathlib import pathlib
import fairseq import fairseq
@@ -28,13 +27,15 @@ from packaging import version
from transformers.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput from transformers.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
from .utils import logging
if version.parse(fairseq.__version__) < version.parse("0.9.0"): if version.parse(fairseq.__version__) < version.parse("0.9.0"):
raise Exception("requires fairseq >= 0.9.0") raise Exception("requires fairseq >= 0.9.0")
logging.basicConfig(level=logging.INFO) logging.set_verbosity_info()
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
SAMPLE_TEXT = "Hello world! cécé herlolip" SAMPLE_TEXT = "Hello world! cécé herlolip"

View File

@@ -16,14 +16,15 @@
import argparse import argparse
import logging
import torch import torch
from transformers import T5Config, T5Model, load_tf_weights_in_t5 from transformers import T5Config, T5Model, load_tf_weights_in_t5
from .utils import logging
logging.basicConfig(level=logging.INFO)
logging.set_verbosity_info()
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):

View File

@@ -16,7 +16,6 @@
import argparse import argparse
import logging
import os import os
import pickle import pickle
import sys import sys
@@ -33,8 +32,10 @@ from transformers import (
) )
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
from .utils import logging
logging.basicConfig(level=logging.INFO)
logging.set_verbosity_info()
# We do this to be able to load python 2 datasets pickles # We do this to be able to load python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918

View File

@@ -17,7 +17,6 @@
import argparse import argparse
import json import json
import logging
import numpy import numpy
import torch import torch
@@ -25,8 +24,10 @@ import torch
from transformers import CONFIG_NAME, WEIGHTS_NAME from transformers import CONFIG_NAME, WEIGHTS_NAME
from transformers.tokenization_xlm import VOCAB_FILES_NAMES from transformers.tokenization_xlm import VOCAB_FILES_NAMES
from .utils import logging
logging.basicConfig(level=logging.INFO)
logging.set_verbosity_info()
def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):

View File

@@ -16,7 +16,6 @@
import argparse import argparse
import logging
import os import os
import torch import torch
@@ -31,6 +30,8 @@ from transformers import (
load_tf_weights_in_xlnet, load_tf_weights_in_xlnet,
) )
from .utils import logging
GLUE_TASKS_NUM_LABELS = { GLUE_TASKS_NUM_LABELS = {
"cola": 2, "cola": 2,
@@ -45,7 +46,7 @@ GLUE_TASKS_NUM_LABELS = {
} }
logging.basicConfig(level=logging.INFO) logging.set_verbosity_info()
def convert_xlnet_checkpoint_to_pytorch( def convert_xlnet_checkpoint_to_pytorch(

View File

@@ -1,4 +1,3 @@
import logging
import os import os
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -14,11 +13,12 @@ from ...tokenization_bart import BartTokenizer, BartTokenizerFast
from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast from ...tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_xlm_roberta import XLMRobertaTokenizer from ...tokenization_xlm_roberta import XLMRobertaTokenizer
from ...utils import logging
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
from ..processors.utils import InputFeatures from ..processors.utils import InputFeatures
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
@dataclass @dataclass

View File

@@ -1,4 +1,3 @@
import logging
import os import os
import pickle import pickle
import time import time
@@ -9,9 +8,10 @@ from torch.utils.data.dataset import Dataset
from filelock import FileLock from filelock import FileLock
from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
class TextDataset(Dataset): class TextDataset(Dataset):

View File

@@ -1,4 +1,3 @@
import logging
import os import os
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -12,10 +11,11 @@ from filelock import FileLock
from ...modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING from ...modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

View File

@@ -10,15 +10,16 @@ that a question is unanswerable.
import collections import collections
import json import json
import logging
import math import math
import re import re
import string import string
from transformers.tokenization_bert import BasicTokenizer from transformers.tokenization_bert import BasicTokenizer
from ...utils import logging
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
def normalize_answer(s): def normalize_answer(s):

View File

@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
""" GLUE processors and helpers """ """ GLUE processors and helpers """
import logging
import os import os
from dataclasses import asdict from dataclasses import asdict
from enum import Enum from enum import Enum
@@ -23,13 +22,14 @@ from typing import List, Optional, Union
from ...file_utils import is_tf_available from ...file_utils import is_tf_available
from ...tokenization_utils import PreTrainedTokenizer from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
from .utils import DataProcessor, InputExample, InputFeatures from .utils import DataProcessor, InputExample, InputFeatures
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
def glue_convert_examples_to_features( def glue_convert_examples_to_features(

View File

@@ -1,5 +1,4 @@
import json import json
import logging
import os import os
from functools import partial from functools import partial
from multiprocessing import Pool, cpu_count from multiprocessing import Pool, cpu_count
@@ -10,6 +9,7 @@ from tqdm import tqdm
from ...file_utils import is_tf_available, is_torch_available from ...file_utils import is_tf_available, is_torch_available
from ...tokenization_bert import whitespace_tokenize from ...tokenization_bert import whitespace_tokenize
from ...tokenization_utils_base import TruncationStrategy from ...tokenization_utils_base import TruncationStrategy
from ...utils import logging
from .utils import DataProcessor from .utils import DataProcessor
@@ -24,7 +24,7 @@ if is_torch_available():
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):

View File

@@ -17,14 +17,14 @@
import csv import csv
import dataclasses import dataclasses
import json import json
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Union
from ...file_utils import is_tf_available, is_torch_available from ...file_utils import is_tf_available, is_torch_available
from ...utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
@dataclass @dataclass

View File

@@ -16,13 +16,13 @@
""" XNLI utils (dataset loading and evaluation) """ """ XNLI utils (dataset loading and evaluation) """
import logging
import os import os
from ...utils import logging
from .utils import DataProcessor, InputExample from .utils import DataProcessor, InputExample
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
class XnliProcessor(DataProcessor): class XnliProcessor(DataProcessor):

View File

@@ -6,7 +6,6 @@ Copyright by the AllenNLP authors.
import fnmatch import fnmatch
import json import json
import logging
import os import os
import re import re
import shutil import shutil
@@ -30,9 +29,10 @@ import requests
from filelock import FileLock from filelock import FileLock
from . import __version__ from . import __version__
from .utils import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
try: try:
USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TF = os.environ.get("USE_TF", "AUTO").upper()
@@ -757,7 +757,7 @@ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict
total=total, total=total,
initial=resume_size, initial=resume_size,
desc="Downloading", desc="Downloading",
disable=bool(logger.getEffectiveLevel() == logging.NOTSET), disable=bool(logging.get_verbosity() > logging.NOTSET),
) )
for chunk in response.iter_content(chunk_size=1024): for chunk in response.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks

View File

@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from .utils import logging
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
class TFGenerationMixin: class TFGenerationMixin:

View File

@@ -14,15 +14,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from torch.nn import functional as F from torch.nn import functional as F
from .utils import logging
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
class GenerationMixin: class GenerationMixin:

View File

@@ -17,7 +17,6 @@
import copy import copy
import json import json
import logging
import os import os
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
@@ -30,9 +29,10 @@ from .file_utils import (
hf_bucket_url, hf_bucket_url,
is_remote_url, is_remote_url,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
class ModelCard: class ModelCard:

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
"""PyTorch ALBERT model. """ """PyTorch ALBERT model. """
import logging
import math import math
import os import os
import warnings import warnings
@@ -44,9 +43,10 @@ from .modeling_outputs import (
TokenClassifierOutput, TokenClassifierOutput,
) )
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices from .modeling_utils import PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "AlbertConfig" _CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer" _TOKENIZER_FOR_DOC = "AlbertTokenizer"

View File

@@ -15,7 +15,6 @@
""" Auto Model class. """ """ Auto Model class. """
import logging
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
@@ -172,9 +171,10 @@ from .modeling_xlnet import (
XLNetLMHeadModel, XLNetLMHeadModel,
XLNetModel, XLNetModel,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
MODEL_MAPPING = OrderedDict( MODEL_MAPPING = OrderedDict(

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""PyTorch BART model, ported from the fairseq repo.""" """PyTorch BART model, ported from the fairseq repo."""
import logging
import math import math
import random import random
import warnings import warnings
@@ -43,9 +42,10 @@ from .modeling_outputs import (
Seq2SeqSequenceClassifierOutput, Seq2SeqSequenceClassifierOutput,
) )
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "BartConfig" _CONFIG_FOR_DOC = "BartConfig"
_TOKENIZER_FOR_DOC = "BartTokenizer" _TOKENIZER_FOR_DOC = "BartTokenizer"

View File

@@ -16,7 +16,6 @@
"""PyTorch BERT model. """ """PyTorch BERT model. """
import logging
import math import math
import os import os
import warnings import warnings
@@ -54,9 +53,10 @@ from .modeling_utils import (
find_pruneable_heads_and_indices, find_pruneable_heads_and_indices,
prune_linear_layer, prune_linear_layer,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "BertConfig" _CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer" _TOKENIZER_FOR_DOC = "BertTokenizer"

View File

@@ -15,8 +15,6 @@
# limitations under the License. # limitations under the License.
"""PyTorch CamemBERT model. """ """PyTorch CamemBERT model. """
import logging
from .configuration_camembert import CamembertConfig from .configuration_camembert import CamembertConfig
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
from .modeling_roberta import ( from .modeling_roberta import (
@@ -28,9 +26,10 @@ from .modeling_roberta import (
RobertaForTokenClassification, RobertaForTokenClassification,
RobertaModel, RobertaModel,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_TOKENIZER_FOR_DOC = "CamembertTokenizer" _TOKENIZER_FOR_DOC = "CamembertTokenizer"

View File

@@ -16,7 +16,6 @@
""" PyTorch CTRL model.""" """ PyTorch CTRL model."""
import logging
import warnings import warnings
import numpy as np import numpy as np
@@ -28,9 +27,10 @@ from .configuration_ctrl import CTRLConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from .modeling_utils import Conv1D, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "CTRLConfig" _CONFIG_FOR_DOC = "CTRLConfig"
_TOKENIZER_FOR_DOC = "CTRLTokenizer" _TOKENIZER_FOR_DOC = "CTRLTokenizer"

View File

@@ -19,7 +19,6 @@
import copy import copy
import logging
import math import math
import warnings import warnings
@@ -50,9 +49,10 @@ from .modeling_utils import (
find_pruneable_heads_and_indices, find_pruneable_heads_and_indices,
prune_linear_layer, prune_linear_layer,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DistilBertConfig" _CONFIG_FOR_DOC = "DistilBertConfig"
_TOKENIZER_FOR_DOC = "DistilBertTokenizer" _TOKENIZER_FOR_DOC = "DistilBertTokenizer"

View File

@@ -15,7 +15,6 @@
""" PyTorch DPR model for Open Domain Question Answering.""" """ PyTorch DPR model for Open Domain Question Answering."""
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
@@ -27,9 +26,10 @@ from .file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_
from .modeling_bert import BertModel from .modeling_bert import BertModel
from .modeling_outputs import BaseModelOutputWithPooling from .modeling_outputs import BaseModelOutputWithPooling
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DPRConfig" _CONFIG_FOR_DOC = "DPRConfig"

View File

@@ -1,4 +1,3 @@
import logging
import os import os
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
@@ -27,9 +26,10 @@ from .modeling_outputs import (
TokenClassifierOutput, TokenClassifierOutput,
) )
from .modeling_utils import SequenceSummary from .modeling_utils import SequenceSummary
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ElectraConfig" _CONFIG_FOR_DOC = "ElectraConfig"
_TOKENIZER_FOR_DOC = "ElectraTokenizer" _TOKENIZER_FOR_DOC = "ElectraTokenizer"

View File

@@ -15,15 +15,15 @@
""" Classes to support Encoder-Decoder architectures """ """ Classes to support Encoder-Decoder architectures """
import logging
from typing import Optional from typing import Optional
from .configuration_encoder_decoder import EncoderDecoderConfig from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
class EncoderDecoderModel(PreTrainedModel): class EncoderDecoderModel(PreTrainedModel):

View File

@@ -15,7 +15,6 @@
""" PyTorch Flaubert model, based on XLM. """ """ PyTorch Flaubert model, based on XLM. """
import logging
import random import random
import torch import torch
@@ -34,9 +33,10 @@ from .modeling_xlm import (
XLMWithLMHeadModel, XLMWithLMHeadModel,
get_masks, get_masks,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "FlaubertConfig" _CONFIG_FOR_DOC = "FlaubertConfig"
_TOKENIZER_FOR_DOC = "FlaubertTokenizer" _TOKENIZER_FOR_DOC = "FlaubertTokenizer"

View File

@@ -16,7 +16,6 @@
"""PyTorch OpenAI GPT-2 model.""" """PyTorch OpenAI GPT-2 model."""
import logging
import os import os
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
@@ -43,9 +42,10 @@ from .modeling_utils import (
find_pruneable_heads_and_indices, find_pruneable_heads_and_indices,
prune_conv1d_layer, prune_conv1d_layer,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "GPT2Config" _CONFIG_FOR_DOC = "GPT2Config"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer" _TOKENIZER_FOR_DOC = "GPT2Tokenizer"

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
"""PyTorch Longformer model. """ """PyTorch Longformer model. """
import logging
import math import math
import warnings import warnings
@@ -47,9 +46,10 @@ from .modeling_utils import (
find_pruneable_heads_and_indices, find_pruneable_heads_and_indices,
prune_linear_layer, prune_linear_layer,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LongformerConfig" _CONFIG_FOR_DOC = "LongformerConfig"
_TOKENIZER_FOR_DOC = "LongformerTokenizer" _TOKENIZER_FOR_DOC = "LongformerTokenizer"

View File

@@ -16,8 +16,6 @@
"""PyTorch MMBT model. """ """PyTorch MMBT model. """
import logging
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
@@ -25,9 +23,10 @@ from torch.nn import CrossEntropyLoss, MSELoss
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
from .modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput from .modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
from .modeling_utils import ModuleUtilsMixin from .modeling_utils import ModuleUtilsMixin
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MMBTConfig" _CONFIG_FOR_DOC = "MMBTConfig"

View File

@@ -20,7 +20,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
import logging
import math import math
import os import os
import warnings import warnings
@@ -53,9 +52,10 @@ from .modeling_outputs import (
TokenClassifierOutput, TokenClassifierOutput,
) )
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MobileBertConfig" _CONFIG_FOR_DOC = "MobileBertConfig"
_TOKENIZER_FOR_DOC = "MobileBertTokenizer" _TOKENIZER_FOR_DOC = "MobileBertTokenizer"

View File

@@ -17,7 +17,6 @@
import json import json
import logging
import math import math
import os import os
import warnings import warnings
@@ -45,9 +44,10 @@ from .modeling_utils import (
find_pruneable_heads_and_indices, find_pruneable_heads_and_indices,
prune_conv1d_layer, prune_conv1d_layer,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "OpenAIGPTConfig" _CONFIG_FOR_DOC = "OpenAIGPTConfig"
_TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer" _TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer"

View File

@@ -15,7 +15,6 @@
# limitations under the License. # limitations under the License.
"""PyTorch REFORMER model. """ """PyTorch REFORMER model. """
import logging
import sys import sys
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass from dataclasses import dataclass
@@ -41,9 +40,10 @@ from .file_utils import (
) )
from .modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput from .modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput
from .modeling_utils import PreTrainedModel, apply_chunking_to_forward from .modeling_utils import PreTrainedModel, apply_chunking_to_forward
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ReformerConfig" _CONFIG_FOR_DOC = "ReformerConfig"
_TOKENIZER_FOR_DOC = "ReformerTokenizer" _TOKENIZER_FOR_DOC = "ReformerTokenizer"

View File

@@ -17,7 +17,6 @@ RetriBERT model
""" """
import logging
import math import math
import torch import torch
@@ -28,9 +27,10 @@ from .configuration_retribert import RetriBertConfig
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
from .modeling_bert import BertLayerNorm, BertModel from .modeling_bert import BertLayerNorm, BertModel
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"yjernite/retribert-base-uncased", "yjernite/retribert-base-uncased",

View File

@@ -16,7 +16,6 @@
"""PyTorch RoBERTa model. """ """PyTorch RoBERTa model. """
import logging
import warnings import warnings
import torch import torch
@@ -39,9 +38,10 @@ from .modeling_outputs import (
SequenceClassifierOutput, SequenceClassifierOutput,
TokenClassifierOutput, TokenClassifierOutput,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "RobertaConfig" _CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer" _TOKENIZER_FOR_DOC = "RobertaTokenizer"

View File

@@ -16,7 +16,6 @@
import copy import copy
import logging
import math import math
import os import os
import warnings import warnings
@@ -36,9 +35,10 @@ from .file_utils import (
) )
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "T5Config" _CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer" _TOKENIZER_FOR_DOC = "T5Tokenizer"

View File

@@ -16,7 +16,6 @@
""" TF 2.0 ALBERT model. """ """ TF 2.0 ALBERT model. """
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -53,9 +52,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "AlbertConfig" _CONFIG_FOR_DOC = "AlbertConfig"
_TOKENIZER_FOR_DOC = "AlbertTokenizer" _TOKENIZER_FOR_DOC = "AlbertTokenizer"

View File

@@ -15,7 +15,6 @@
""" Auto Model class. """ """ Auto Model class. """
import logging
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
@@ -139,9 +138,10 @@ from .modeling_tf_xlnet import (
TFXLNetLMHeadModel, TFXLNetLMHeadModel,
TFXLNetModel, TFXLNetModel,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
TF_MODEL_MAPPING = OrderedDict( TF_MODEL_MAPPING = OrderedDict(

View File

@@ -16,7 +16,6 @@
""" TF 2.0 BERT model. """ """ TF 2.0 BERT model. """
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -56,9 +55,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "BertConfig" _CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer" _TOKENIZER_FOR_DOC = "BertTokenizer"

View File

@@ -15,9 +15,6 @@
# limitations under the License. # limitations under the License.
""" TF 2.0 CamemBERT model. """ """ TF 2.0 CamemBERT model. """
import logging
from .configuration_camembert import CamembertConfig from .configuration_camembert import CamembertConfig
from .file_utils import add_start_docstrings from .file_utils import add_start_docstrings
from .modeling_tf_roberta import ( from .modeling_tf_roberta import (
@@ -28,9 +25,10 @@ from .modeling_tf_roberta import (
TFRobertaForTokenClassification, TFRobertaForTokenClassification,
TFRobertaModel, TFRobertaModel,
) )
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all CamemBERT models at https://huggingface.co/models?filter=camembert # See all CamemBERT models at https://huggingface.co/models?filter=camembert

View File

@@ -16,8 +16,6 @@
""" TF 2.0 CTRL model.""" """ TF 2.0 CTRL model."""
import logging
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
@@ -32,9 +30,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "CTRLConfig" _CONFIG_FOR_DOC = "CTRLConfig"
_TOKENIZER_FOR_DOC = "CTRLTokenizer" _TOKENIZER_FOR_DOC = "CTRLTokenizer"

View File

@@ -16,7 +16,6 @@
""" """
import logging
import math import math
import numpy as np import numpy as np
@@ -50,9 +49,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DistilBertConfig" _CONFIG_FOR_DOC = "DistilBertConfig"
_TOKENIZER_FOR_DOC = "DistilBertTokenizer" _TOKENIZER_FOR_DOC = "DistilBertTokenizer"

View File

@@ -1,4 +1,3 @@
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -35,9 +34,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ElectraConfig" _CONFIG_FOR_DOC = "ElectraConfig"
_TOKENIZER_FOR_DOC = "ElectraTokenizer" _TOKENIZER_FOR_DOC = "ElectraTokenizer"

View File

@@ -15,7 +15,6 @@
""" TF 2.0 Flaubert model. """ TF 2.0 Flaubert model.
""" """
import logging
import random import random
import tensorflow as tf import tensorflow as tf
@@ -36,9 +35,10 @@ from .modeling_tf_xlm import (
get_masks, get_masks,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# See all Flaubert models at https://huggingface.co/models?filter=flaubert # See all Flaubert models at https://huggingface.co/models?filter=flaubert

View File

@@ -16,7 +16,6 @@
""" TF 2.0 OpenAI GPT-2 model. """ """ TF 2.0 OpenAI GPT-2 model. """
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@@ -43,9 +42,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "GPT2Config" _CONFIG_FOR_DOC = "GPT2Config"
_TOKENIZER_FOR_DOC = "GPT2Tokenizer" _TOKENIZER_FOR_DOC = "GPT2Tokenizer"

View File

@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
"""Tensorflow Longformer model. """ """Tensorflow Longformer model. """
import logging
import tensorflow as tf import tensorflow as tf
from .configuration_longformer import LongformerConfig from .configuration_longformer import LongformerConfig
@@ -37,9 +35,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LongformerConfig" _CONFIG_FOR_DOC = "LongformerConfig"
_TOKENIZER_FOR_DOC = "LongformerTokenizer" _TOKENIZER_FOR_DOC = "LongformerTokenizer"

View File

@@ -16,7 +16,6 @@
""" TF 2.0 MobileBERT model. """ """ TF 2.0 MobileBERT model. """
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -54,9 +53,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MobileBertConfig" _CONFIG_FOR_DOC = "MobileBertConfig"
_TOKENIZER_FOR_DOC = "MobileBertTokenizer" _TOKENIZER_FOR_DOC = "MobileBertTokenizer"

View File

@@ -16,7 +16,6 @@
""" TF 2.0 OpenAI GPT model.""" """ TF 2.0 OpenAI GPT model."""
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -43,9 +42,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "OpenAIGPTConfig" _CONFIG_FOR_DOC = "OpenAIGPTConfig"
_TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer" _TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer"

View File

@@ -16,14 +16,15 @@
""" PyTorch - TF 2.0 general utilities.""" """ PyTorch - TF 2.0 general utilities."""
import logging
import os import os
import re import re
import numpy import numpy
from .utils import logging
logger = logging.getLogger(__name__)
logger = logging.get_logger(__name__)
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=""): def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=""):

View File

@@ -16,8 +16,6 @@
""" TF 2.0 RoBERTa model. """ """ TF 2.0 RoBERTa model. """
import logging
import tensorflow as tf import tensorflow as tf
from .configuration_roberta import RobertaConfig from .configuration_roberta import RobertaConfig
@@ -48,9 +46,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils_base import BatchEncoding from .tokenization_utils_base import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "RobertaConfig" _CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer" _TOKENIZER_FOR_DOC = "RobertaTokenizer"

View File

@@ -18,7 +18,6 @@
import copy import copy
import itertools import itertools
import logging
import math import math
import warnings import warnings
@@ -42,9 +41,10 @@ from .modeling_tf_utils import (
shape_list, shape_list,
) )
from .tokenization_utils import BatchEncoding from .tokenization_utils import BatchEncoding
from .utils import logging
logger = logging.getLogger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "T5Config" _CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer" _TOKENIZER_FOR_DOC = "T5Tokenizer"

Some files were not shown because too many files have changed in this diff Show More