update conversion scripts
This commit is contained in:
@@ -1,4 +1,15 @@
|
|||||||
__version__ = "1.2.0"
|
__version__ = "1.2.0"
|
||||||
|
# Work around to update TensorFlow's absl.logging threshold which alters the
|
||||||
|
# default Python logging output behavior when present.
|
||||||
|
# see: https://github.com/abseil/abseil-py/issues/99
|
||||||
|
# and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
|
||||||
|
try:
|
||||||
|
import absl.logging
|
||||||
|
absl.logging.set_verbosity('info')
|
||||||
|
absl.logging.set_stderrthreshold('info')
|
||||||
|
absl.logging._warn_preinit_stderr = False
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
# Tokenizer
|
# Tokenizer
|
||||||
from .tokenization_utils import (PreTrainedTokenizer)
|
from .tokenization_utils import (PreTrainedTokenizer)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from io import open
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_transformers.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME,
|
from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME,
|
||||||
GPT2Config,
|
GPT2Config,
|
||||||
GPT2Model,
|
GPT2Model,
|
||||||
load_tf_weights_in_gpt2)
|
load_tf_weights_in_gpt2)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from io import open
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_transformers.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
|
from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME,
|
||||||
OpenAIGPTConfig,
|
OpenAIGPTConfig,
|
||||||
OpenAIGPTModel,
|
OpenAIGPTModel,
|
||||||
load_tf_weights_in_openai_gpt)
|
load_tf_weights_in_openai_gpt)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import argparse
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from pytorch_transformers.modeling import BertModel
|
from pytorch_transformers import BertModel
|
||||||
|
|
||||||
|
|
||||||
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
|
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
|
||||||
|
|||||||
@@ -23,12 +23,12 @@ import torch
|
|||||||
|
|
||||||
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
||||||
from fairseq.modules import TransformerSentenceEncoderLayer
|
from fairseq.modules import TransformerSentenceEncoderLayer
|
||||||
from pytorch_transformers.modeling_bert import (BertConfig, BertEncoder,
|
from pytorch_transformers import (BertConfig, BertEncoder,
|
||||||
BertIntermediate, BertLayer,
|
BertIntermediate, BertLayer,
|
||||||
BertModel, BertOutput,
|
BertModel, BertOutput,
|
||||||
BertSelfAttention,
|
BertSelfAttention,
|
||||||
BertSelfOutput)
|
BertSelfOutput)
|
||||||
from pytorch_transformers.modeling_roberta import (RobertaEmbeddings,
|
from pytorch_transformers import (RobertaEmbeddings,
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
RobertaForSequenceClassification,
|
RobertaForSequenceClassification,
|
||||||
RobertaModel)
|
RobertaModel)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from __future__ import print_function
|
|||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
from pytorch_transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import torch
|
|||||||
import pytorch_transformers.tokenization_transfo_xl as data_utils
|
import pytorch_transformers.tokenization_transfo_xl as data_utils
|
||||||
|
|
||||||
from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME
|
from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from pytorch_transformers.modeling_transfo_xl import (TransfoXLConfig, TransfoXLLMHeadModel,
|
from pytorch_transformers import (TransfoXLConfig, TransfoXLLMHeadModel,
|
||||||
load_tf_weights_in_transfo_xl)
|
load_tf_weights_in_transfo_xl)
|
||||||
from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
|
from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES)
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from io import open
|
|||||||
import torch
|
import torch
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME
|
from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import os
|
|||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_transformers.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME,
|
from pytorch_transformers import (CONFIG_NAME, WEIGHTS_NAME,
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
XLNetLMHeadModel, XLNetForQuestionAnswering,
|
XLNetLMHeadModel, XLNetForQuestionAnswering,
|
||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
|
|||||||
Reference in New Issue
Block a user