Fix imports in conversion scripts (#9674)
This commit is contained in:
@@ -19,8 +19,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...utils import logging
|
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||||
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -23,9 +23,15 @@ import fairseq
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from ...utils import logging
|
from transformers import (
|
||||||
from . import BartConfig, BartForConditionalGeneration, BartForSequenceClassification, BartModel, BartTokenizer
|
BartConfig,
|
||||||
from .modeling_bart import _make_linear_from_emb
|
BartForConditionalGeneration,
|
||||||
|
BartForSequenceClassification,
|
||||||
|
BartModel,
|
||||||
|
BartTokenizer,
|
||||||
|
)
|
||||||
|
from transformers.models.bart.modeling_bart import _make_linear_from_emb
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
||||||
|
|||||||
@@ -28,8 +28,8 @@ import re
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...utils import logging
|
from transformers import BertConfig, BertModel
|
||||||
from . import BertConfig, BertModel
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...utils import logging
|
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||||
from . import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from . import BertModel
|
from transformers import BertModel
|
||||||
|
|
||||||
|
|
||||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...models.bart import BartConfig, BartForConditionalGeneration
|
from transformers import BartConfig, BartForConditionalGeneration
|
||||||
from ...utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...file_utils import WEIGHTS_NAME
|
from transformers.file_utils import WEIGHTS_NAME
|
||||||
|
|
||||||
|
|
||||||
DIALOGPT_MODELS = ["small", "medium", "large"]
|
DIALOGPT_MODELS = ["small", "medium", "large"]
|
||||||
|
|||||||
@@ -19,8 +19,7 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from torch.serialization import default_restore_location
|
from torch.serialization import default_restore_location
|
||||||
|
|
||||||
from ...models.bert import BertConfig
|
from .transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||||
from . import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
|
||||||
|
|
||||||
|
|
||||||
CheckpointState = collections.namedtuple(
|
CheckpointState = collections.namedtuple(
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...utils import logging
|
from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
||||||
from . import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -31,10 +31,11 @@ import torch
|
|||||||
from fairseq import hub_utils
|
from fairseq import hub_utils
|
||||||
from fairseq.data.dictionary import Dictionary
|
from fairseq.data.dictionary import Dictionary
|
||||||
|
|
||||||
from ...file_utils import WEIGHTS_NAME
|
from transfomers.models.fsmt.tokenization_fsmt import VOCAB_FILES_NAMES
|
||||||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
from transformers import FSMTConfig, FSMTForConditionalGeneration
|
||||||
from ...utils import logging
|
from transformers.file_utils import WEIGHTS_NAME
|
||||||
from . import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration
|
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_warning()
|
logging.set_verbosity_warning()
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_availa
|
|||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig"],
|
"configuration_funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig"],
|
||||||
|
"convert_funnel_original_tf_checkpoint_to_pytorch": [],
|
||||||
"tokenization_funnel": ["FunnelTokenizer"],
|
"tokenization_funnel": ["FunnelTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,14 +16,14 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from . import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
|
from transformers import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||||
|
|||||||
@@ -19,9 +19,9 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
from transformers import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
||||||
from ...utils import logging
|
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from . import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import argparse
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from . import LongformerForQuestionAnswering, LongformerModel
|
from transformers import LongformerForQuestionAnswering, LongformerModel
|
||||||
|
|
||||||
|
|
||||||
class LightningModel(pl.LightningModule):
|
class LightningModel(pl.LightningModule):
|
||||||
|
|||||||
@@ -16,14 +16,14 @@
|
|||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from . import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
|
from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.set_verbosity_info()
|
||||||
|
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from .convert_marian_to_pytorch import (
|
from transformers.models.marian.convert_marian_to_pytorch import (
|
||||||
FRONT_MATTER_TEMPLATE,
|
FRONT_MATTER_TEMPLATE,
|
||||||
_parse_readme,
|
_parse_readme,
|
||||||
convert_all_sentencepiece_models,
|
convert_all_sentencepiece_models,
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ...hf_api import HfApi
|
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
|
||||||
from . import MarianConfig, MarianMTModel, MarianTokenizer
|
from transformers.hf_api import HfApi
|
||||||
|
|
||||||
|
|
||||||
def remove_suffix(text: str, suffix: str):
|
def remove_suffix(text: str, suffix: str):
|
||||||
|
|||||||
@@ -16,9 +16,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..bart import BartForConditionalGeneration
|
from transformers import BartForConditionalGeneration, MBartConfig
|
||||||
from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
|
from transformers.models.bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
|
||||||
from . import MBartConfig
|
|
||||||
|
|
||||||
|
|
||||||
def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
|
def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...utils import logging
|
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
||||||
from . import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ import tensorflow as tf
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from . import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
|
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
|
||||||
from .configuration_pegasus import DEFAULTS, task_specific_params
|
from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params
|
||||||
|
|
||||||
|
|
||||||
PATTERNS = [
|
PATTERNS = [
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
|
||||||
|
|
||||||
# transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
|
# transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
|
||||||
# original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
|
# original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
|
||||||
from transformers_old.modeling_prophetnet import (
|
from transformers_old.modeling_prophetnet import (
|
||||||
@@ -28,8 +30,6 @@ from transformers_old.modeling_xlm_prophetnet import (
|
|||||||
XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld,
|
XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld,
|
||||||
)
|
)
|
||||||
|
|
||||||
from . import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ import pickle
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...utils import logging
|
from transformers import ReformerConfig, ReformerModelWithLMHead
|
||||||
from . import ReformerConfig, ReformerModelWithLMHead
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -24,9 +24,15 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
|||||||
from fairseq.modules import TransformerSentenceEncoderLayer
|
from fairseq.modules import TransformerSentenceEncoderLayer
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from ...models.bert.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
|
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||||
from ...utils import logging
|
from transformers.models.bert.modeling_bert import (
|
||||||
from .modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
|
BertIntermediate,
|
||||||
|
BertLayer,
|
||||||
|
BertOutput,
|
||||||
|
BertSelfAttention,
|
||||||
|
BertSelfOutput,
|
||||||
|
)
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||||
|
|||||||
@@ -17,8 +17,8 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from ...utils import logging
|
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||||
from . import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -17,8 +17,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from ...utils import logging
|
from transformers import (
|
||||||
from . import (
|
|
||||||
TapasConfig,
|
TapasConfig,
|
||||||
TapasForMaskedLM,
|
TapasForMaskedLM,
|
||||||
TapasForQuestionAnswering,
|
TapasForQuestionAnswering,
|
||||||
@@ -27,6 +26,7 @@ from . import (
|
|||||||
TapasTokenizer,
|
TapasTokenizer,
|
||||||
load_tf_weights_in_tapas,
|
load_tf_weights_in_tapas,
|
||||||
)
|
)
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -22,11 +22,11 @@ import sys
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
|
||||||
from ...utils import logging
|
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from . import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
|
from transformers.models.transfo_xl import tokenization_transfo_xl as data_utils
|
||||||
from . import tokenization_transfo_xl as data_utils
|
from transformers.models.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||||
from .tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ import json
|
|||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from ...utils import logging
|
from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES
|
||||||
from .tokenization_xlm import VOCAB_FILES_NAMES
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -20,15 +20,15 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
from transformers import (
|
||||||
from ...utils import logging
|
|
||||||
from . import (
|
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
XLNetForQuestionAnswering,
|
XLNetForQuestionAnswering,
|
||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
XLNetLMHeadModel,
|
XLNetLMHeadModel,
|
||||||
load_tf_weights_in_xlnet,
|
load_tf_weights_in_xlnet,
|
||||||
)
|
)
|
||||||
|
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
GLUE_TASKS_NUM_LABELS = {
|
GLUE_TASKS_NUM_LABELS = {
|
||||||
|
|||||||
Reference in New Issue
Block a user