Sort imports with isort.
This is the result of:
$ isort --recursive examples templates transformers utils hubconf.py setup.py
This commit is contained in:
@@ -18,12 +18,14 @@
|
|||||||
# If checking the tensors placement
|
# If checking the tensors placement
|
||||||
# tf.debugging.set_log_device_placement(True)
|
# tf.debugging.set_log_device_placement(True)
|
||||||
|
|
||||||
from typing import List
|
|
||||||
import timeit
|
|
||||||
from transformers import is_tf_available, is_torch_available
|
|
||||||
from time import time
|
|
||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
|
import timeit
|
||||||
|
from time import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -33,7 +35,6 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
|
|
||||||
from transformers import AutoConfig, AutoTokenizer
|
|
||||||
|
|
||||||
input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as
|
input_text = """Bent over their instruments, three hundred Fertilizers were plunged, as
|
||||||
the Director of Hatcheries and Conditioning entered the room, in the
|
the Director of Hatcheries and Conditioning entered the room, in the
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
from pathlib import Path
|
|
||||||
import tarfile
|
import tarfile
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers.tokenization_camembert import CamembertTokenizer
|
|
||||||
from transformers.modeling_camembert import CamembertForMaskedLM
|
from transformers.modeling_camembert import CamembertForMaskedLM
|
||||||
|
from transformers.tokenization_camembert import CamembertTokenizer
|
||||||
|
|
||||||
|
|
||||||
def fill_mask(masked_input, model, tokenizer, topk=5):
|
def fill_mask(masked_input, model, tokenizer, topk=5):
|
||||||
|
|||||||
@@ -28,26 +28,27 @@
|
|||||||
--train_batch_size 16 \
|
--train_batch_size 16 \
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import csv
|
import csv
|
||||||
import random
|
|
||||||
import logging
|
import logging
|
||||||
from tqdm import tqdm, trange
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
CONFIG_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
AdamW,
|
||||||
OpenAIGPTDoubleHeadsModel,
|
OpenAIGPTDoubleHeadsModel,
|
||||||
OpenAIGPTTokenizer,
|
OpenAIGPTTokenizer,
|
||||||
AdamW,
|
|
||||||
cached_path,
|
cached_path,
|
||||||
WEIGHTS_NAME,
|
|
||||||
CONFIG_NAME,
|
|
||||||
get_linear_schedule_with_warmup,
|
get_linear_schedule_with_warmup,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
|
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
|||||||
@@ -19,28 +19,34 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
|
||||||
import csv
|
import csv
|
||||||
|
import glob
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
import glob
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
AdamW,
|
||||||
|
BertConfig,
|
||||||
|
BertForMultipleChoice,
|
||||||
|
BertTokenizer,
|
||||||
|
get_linear_schedule_with_warmup,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
except:
|
except:
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
|
||||||
|
|
||||||
from transformers import WEIGHTS_NAME, BertConfig, BertForMultipleChoice, BertTokenizer
|
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -23,12 +23,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
import math
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer
|
from transformers import TransfoXLCorpus, TransfoXLLMHeadModel, TransfoXLTokenizer
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||||
|
|||||||
@@ -15,31 +15,31 @@
|
|||||||
""" The distiller to distil the student.
|
""" The distiller to distil the student.
|
||||||
Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
Adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import math
|
import math
|
||||||
import psutil
|
import os
|
||||||
import time
|
import time
|
||||||
from tqdm import trange, tqdm
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from torch.utils.data import RandomSampler, BatchSampler, DataLoader
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
|
||||||
|
from lm_seqs_dataset import LmSeqsDataset
|
||||||
|
from transformers import get_linear_schedule_with_warmup
|
||||||
|
from utils import logger
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
except:
|
except:
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from transformers import get_linear_schedule_with_warmup
|
|
||||||
|
|
||||||
from utils import logger
|
|
||||||
from lm_seqs_dataset import LmSeqsDataset
|
|
||||||
from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups
|
|
||||||
|
|
||||||
|
|
||||||
class Distiller:
|
class Distiller:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -17,8 +17,8 @@
|
|||||||
import bisect
|
import bisect
|
||||||
import copy
|
import copy
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from torch.utils.data.sampler import BatchSampler, Sampler
|
from torch.utils.data.sampler import BatchSampler, Sampler
|
||||||
|
|
||||||
from utils import logger
|
from utils import logger
|
||||||
|
|||||||
@@ -15,10 +15,10 @@
|
|||||||
""" Dataset to distilled models
|
""" Dataset to distilled models
|
||||||
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||||
"""
|
"""
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from utils import logger
|
from utils import logger
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,56 +18,58 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import glob
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
import torch.nn.functional as F
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
except:
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
AdamW,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
BertForQuestionAnswering,
|
BertForQuestionAnswering,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForQuestionAnswering,
|
||||||
|
DistilBertTokenizer,
|
||||||
XLMConfig,
|
XLMConfig,
|
||||||
XLMForQuestionAnswering,
|
XLMForQuestionAnswering,
|
||||||
XLMTokenizer,
|
XLMTokenizer,
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
XLNetForQuestionAnswering,
|
XLNetForQuestionAnswering,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
DistilBertConfig,
|
get_linear_schedule_with_warmup,
|
||||||
DistilBertForQuestionAnswering,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
|
||||||
|
|
||||||
from ..utils_squad import (
|
from ..utils_squad import (
|
||||||
read_squad_examples,
|
|
||||||
convert_examples_to_features,
|
|
||||||
RawResult,
|
RawResult,
|
||||||
write_predictions,
|
|
||||||
RawResultExtended,
|
RawResultExtended,
|
||||||
|
convert_examples_to_features,
|
||||||
|
read_squad_examples,
|
||||||
|
write_predictions,
|
||||||
write_predictions_extended,
|
write_predictions_extended,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||||
# You can remove it from the dependencies if you are using this script outside of the library
|
# You can remove it from the dependencies if you are using this script outside of the library
|
||||||
# We've added it here for automated tests (see examples/test_examples.py file)
|
# We've added it here for automated tests (see examples/test_examples.py file)
|
||||||
from ..utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
|
from ..utils_squad_evaluate import EVAL_OPTS
|
||||||
|
from ..utils_squad_evaluate import main as evaluate_on_squad
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
except:
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -16,12 +16,15 @@
|
|||||||
Preprocessing script before distillation.
|
Preprocessing script before distillation.
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer
|
|
||||||
import logging
|
from transformers import BertTokenizer, GPT2Tokenizer, RobertaTokenizer
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||||
|
|||||||
@@ -16,10 +16,13 @@
|
|||||||
Preprocessing script before training the distilled model.
|
Preprocessing script before training the distilled model.
|
||||||
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
|
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
|
||||||
"""
|
"""
|
||||||
from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel
|
|
||||||
import torch
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import BertForMaskedLM, GPT2LMHeadModel, RobertaForMaskedLM
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
|
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
|
||||||
|
|||||||
@@ -16,10 +16,13 @@
|
|||||||
Preprocessing script before training DistilBERT.
|
Preprocessing script before training DistilBERT.
|
||||||
Specific to BERT -> DistilBERT.
|
Specific to BERT -> DistilBERT.
|
||||||
"""
|
"""
|
||||||
from transformers import BertForMaskedLM, RobertaForMaskedLM
|
|
||||||
import torch
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import BertForMaskedLM, RobertaForMaskedLM
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
|
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
|
||||||
|
|||||||
@@ -15,10 +15,11 @@
|
|||||||
"""
|
"""
|
||||||
Preprocessing script before training the distilled model.
|
Preprocessing script before training the distilled model.
|
||||||
"""
|
"""
|
||||||
from collections import Counter
|
|
||||||
import argparse
|
import argparse
|
||||||
import pickle
|
|
||||||
import logging
|
import logging
|
||||||
|
import pickle
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
|
||||||
|
|||||||
@@ -16,22 +16,32 @@
|
|||||||
Training the distilled model.
|
Training the distilled model.
|
||||||
Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
|
Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
import pickle
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BertConfig, BertForMaskedLM, BertTokenizer
|
|
||||||
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
|
|
||||||
from transformers import DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer
|
|
||||||
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
|
|
||||||
|
|
||||||
from distiller import Distiller
|
from distiller import Distiller
|
||||||
from utils import git_log, logger, init_gpu_params, set_seed
|
|
||||||
from lm_seqs_dataset import LmSeqsDataset
|
from lm_seqs_dataset import LmSeqsDataset
|
||||||
|
from transformers import (
|
||||||
|
BertConfig,
|
||||||
|
BertForMaskedLM,
|
||||||
|
BertTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForMaskedLM,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
GPT2Config,
|
||||||
|
GPT2LMHeadModel,
|
||||||
|
GPT2Tokenizer,
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaForMaskedLM,
|
||||||
|
RobertaTokenizer,
|
||||||
|
)
|
||||||
|
from utils import git_log, init_gpu_params, logger, set_seed
|
||||||
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
|
|||||||
@@ -15,14 +15,16 @@
|
|||||||
""" Utils to train DistilBERT
|
""" Utils to train DistilBERT
|
||||||
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
|
||||||
"""
|
"""
|
||||||
import git
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import logging
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import git
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
|
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
|
||||||
|
|||||||
@@ -19,32 +19,33 @@ from __future__ import absolute_import, division, print_function
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import json
|
|
||||||
from sklearn.metrics import f1_score
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from sklearn.metrics import f1_score
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
except:
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_mmimdb_labels, get_image_transforms
|
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
AdamW,
|
||||||
|
AlbertConfig,
|
||||||
|
AlbertModel,
|
||||||
|
AlbertTokenizer,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
BertModel,
|
BertModel,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertModel,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
MMBTConfig,
|
||||||
|
MMBTForClassification,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
RobertaModel,
|
RobertaModel,
|
||||||
RobertaTokenizer,
|
RobertaTokenizer,
|
||||||
@@ -54,17 +55,16 @@ from transformers import (
|
|||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
XLNetModel,
|
XLNetModel,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
DistilBertConfig,
|
get_linear_schedule_with_warmup,
|
||||||
DistilBertModel,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
AlbertConfig,
|
|
||||||
AlbertModel,
|
|
||||||
AlbertTokenizer,
|
|
||||||
MMBTForClassification,
|
|
||||||
MMBTConfig,
|
|
||||||
)
|
)
|
||||||
|
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
except:
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -17,13 +17,15 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from torch.utils.data import Dataset
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}
|
POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}
|
||||||
|
|
||||||
|
|||||||
@@ -34,10 +34,11 @@ import torch.nn.functional as F
|
|||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
|
from pplm_classification_head import ClassificationHead
|
||||||
from transformers import GPT2Tokenizer
|
from transformers import GPT2Tokenizer
|
||||||
from transformers.file_utils import cached_path
|
from transformers.file_utils import cached_path
|
||||||
from transformers.modeling_gpt2 import GPT2LMHeadModel
|
from transformers.modeling_gpt2 import GPT2LMHeadModel
|
||||||
from pplm_classification_head import ClassificationHead
|
|
||||||
|
|
||||||
PPLM_BOW = 1
|
PPLM_BOW = 1
|
||||||
PPLM_DISCRIM = 2
|
PPLM_DISCRIM = 2
|
||||||
|
|||||||
@@ -24,16 +24,16 @@ import time
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim
|
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.utils.data as data
|
import torch.utils.data as data
|
||||||
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
|
||||||
from torchtext import data as torchtext_data
|
|
||||||
from torchtext import datasets
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
from nltk.tokenize.treebank import TreebankWordDetokenizer
|
||||||
from pplm_classification_head import ClassificationHead
|
from pplm_classification_head import ClassificationHead
|
||||||
|
from torchtext import data as torchtext_data
|
||||||
|
from torchtext import datasets
|
||||||
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||||
|
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|||||||
@@ -19,19 +19,19 @@
|
|||||||
Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
|
Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||||
which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
|
which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from datetime import timedelta, datetime
|
import os
|
||||||
from tqdm import tqdm
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, Subset
|
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
from torch.utils.data import DataLoader, SequentialSampler, Subset, TensorDataset
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from run_glue import ALL_MODELS, MODEL_CLASSES, load_and_cache_examples, set_seed
|
||||||
from transformers import (
|
from transformers import (
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
@@ -44,13 +44,11 @@ from transformers import (
|
|||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
from run_glue import set_seed, load_and_cache_examples, ALL_MODELS, MODEL_CLASSES
|
|
||||||
|
|
||||||
from transformers import glue_compute_metrics as compute_metrics
|
from transformers import glue_compute_metrics as compute_metrics
|
||||||
from transformers import glue_output_modes as output_modes
|
from transformers import glue_output_modes as output_modes
|
||||||
from transformers import glue_processors as processors
|
from transformers import glue_processors as processors
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -21,15 +21,23 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
from transformers import (
|
||||||
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
|
CTRLLMHeadModel,
|
||||||
from transformers import XLNetLMHeadModel, XLNetTokenizer
|
CTRLTokenizer,
|
||||||
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
|
GPT2LMHeadModel,
|
||||||
from transformers import CTRLLMHeadModel, CTRLTokenizer
|
GPT2Tokenizer,
|
||||||
from transformers import XLMWithLMHeadModel, XLMTokenizer
|
OpenAIGPTLMHeadModel,
|
||||||
|
OpenAIGPTTokenizer,
|
||||||
|
TransfoXLLMHeadModel,
|
||||||
|
TransfoXLTokenizer,
|
||||||
|
XLMTokenizer,
|
||||||
|
XLMWithLMHeadModel,
|
||||||
|
XLNetLMHeadModel,
|
||||||
|
XLNetTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
|||||||
@@ -19,54 +19,54 @@ from __future__ import absolute_import, division, print_function
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
import glob
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import json
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
AdamW,
|
||||||
|
AlbertConfig,
|
||||||
|
AlbertForSequenceClassification,
|
||||||
|
AlbertTokenizer,
|
||||||
|
BertConfig,
|
||||||
|
BertForSequenceClassification,
|
||||||
|
BertTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForSequenceClassification,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaForSequenceClassification,
|
||||||
|
RobertaTokenizer,
|
||||||
|
XLMConfig,
|
||||||
|
XLMForSequenceClassification,
|
||||||
|
XLMRobertaConfig,
|
||||||
|
XLMRobertaForSequenceClassification,
|
||||||
|
XLMRobertaTokenizer,
|
||||||
|
XLMTokenizer,
|
||||||
|
XLNetConfig,
|
||||||
|
XLNetForSequenceClassification,
|
||||||
|
XLNetTokenizer,
|
||||||
|
get_linear_schedule_with_warmup,
|
||||||
|
)
|
||||||
|
from transformers import glue_compute_metrics as compute_metrics
|
||||||
|
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||||
|
from transformers import glue_output_modes as output_modes
|
||||||
|
from transformers import glue_processors as processors
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
except:
|
except:
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
|
||||||
|
|
||||||
from transformers import (
|
|
||||||
WEIGHTS_NAME,
|
|
||||||
BertConfig,
|
|
||||||
BertForSequenceClassification,
|
|
||||||
BertTokenizer,
|
|
||||||
RobertaConfig,
|
|
||||||
RobertaForSequenceClassification,
|
|
||||||
RobertaTokenizer,
|
|
||||||
XLMConfig,
|
|
||||||
XLMForSequenceClassification,
|
|
||||||
XLMTokenizer,
|
|
||||||
XLNetConfig,
|
|
||||||
XLNetForSequenceClassification,
|
|
||||||
XLNetTokenizer,
|
|
||||||
DistilBertConfig,
|
|
||||||
DistilBertForSequenceClassification,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
AlbertConfig,
|
|
||||||
AlbertForSequenceClassification,
|
|
||||||
AlbertTokenizer,
|
|
||||||
XLMRobertaConfig,
|
|
||||||
XLMRobertaForSequenceClassification,
|
|
||||||
XLMRobertaTokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
|
||||||
|
|
||||||
from transformers import glue_compute_metrics as compute_metrics
|
|
||||||
from transformers import glue_output_modes as output_modes
|
|
||||||
from transformers import glue_processors as processors
|
|
||||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -32,23 +32,22 @@ import shutil
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
|
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
except:
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
AdamW,
|
AdamW,
|
||||||
get_linear_schedule_with_warmup,
|
|
||||||
BertConfig,
|
BertConfig,
|
||||||
BertForMaskedLM,
|
BertForMaskedLM,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
|
CamembertConfig,
|
||||||
|
CamembertForMaskedLM,
|
||||||
|
CamembertTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForMaskedLM,
|
||||||
|
DistilBertTokenizer,
|
||||||
GPT2Config,
|
GPT2Config,
|
||||||
GPT2LMHeadModel,
|
GPT2LMHeadModel,
|
||||||
GPT2Tokenizer,
|
GPT2Tokenizer,
|
||||||
@@ -58,15 +57,16 @@ from transformers import (
|
|||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
RobertaTokenizer,
|
RobertaTokenizer,
|
||||||
DistilBertConfig,
|
get_linear_schedule_with_warmup,
|
||||||
DistilBertForMaskedLM,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
CamembertConfig,
|
|
||||||
CamembertForMaskedLM,
|
|
||||||
CamembertTokenizer,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
except:
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,35 +23,34 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
AdamW,
|
||||||
|
BertConfig,
|
||||||
|
BertForMultipleChoice,
|
||||||
|
BertTokenizer,
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaForMultipleChoice,
|
||||||
|
RobertaTokenizer,
|
||||||
|
XLNetConfig,
|
||||||
|
XLNetForMultipleChoice,
|
||||||
|
XLNetTokenizer,
|
||||||
|
get_linear_schedule_with_warmup,
|
||||||
|
)
|
||||||
|
from utils_multiple_choice import convert_examples_to_features, processors
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
except:
|
except:
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
|
||||||
|
|
||||||
from transformers import (
|
|
||||||
WEIGHTS_NAME,
|
|
||||||
BertConfig,
|
|
||||||
BertForMultipleChoice,
|
|
||||||
BertTokenizer,
|
|
||||||
XLNetConfig,
|
|
||||||
XLNetForMultipleChoice,
|
|
||||||
XLNetTokenizer,
|
|
||||||
RobertaConfig,
|
|
||||||
RobertaForMultipleChoice,
|
|
||||||
RobertaTokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
|
||||||
|
|
||||||
from utils_multiple_choice import convert_examples_to_features, processors
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -25,20 +25,35 @@ import random
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from seqeval.metrics import precision_score, recall_score, f1_score
|
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
|
from seqeval.metrics import f1_score, precision_score, recall_score
|
||||||
|
from transformers import (
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
AdamW,
|
||||||
|
BertConfig,
|
||||||
|
BertForTokenClassification,
|
||||||
|
BertTokenizer,
|
||||||
|
CamembertConfig,
|
||||||
|
CamembertForTokenClassification,
|
||||||
|
CamembertTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForTokenClassification,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaForTokenClassification,
|
||||||
|
RobertaTokenizer,
|
||||||
|
XLMRobertaConfig,
|
||||||
|
XLMRobertaForTokenClassification,
|
||||||
|
XLMRobertaTokenizer,
|
||||||
|
get_linear_schedule_with_warmup,
|
||||||
|
)
|
||||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
|
||||||
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
|
|
||||||
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
|
|
||||||
from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
|
|
||||||
from transformers import CamembertConfig, CamembertForTokenClassification, CamembertTokenizer
|
|
||||||
from transformers import XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -16,57 +16,57 @@
|
|||||||
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
|
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
|
|
||||||
from transformers.data.metrics.squad_metrics import (
|
|
||||||
compute_predictions_logits,
|
|
||||||
compute_predictions_log_probs,
|
|
||||||
squad_evaluate,
|
|
||||||
)
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import glob
|
|
||||||
import timeit
|
import timeit
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
except:
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
AdamW,
|
||||||
|
AlbertConfig,
|
||||||
|
AlbertForQuestionAnswering,
|
||||||
|
AlbertTokenizer,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
BertForQuestionAnswering,
|
BertForQuestionAnswering,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForQuestionAnswering,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
RobertaConfig,
|
||||||
RobertaForQuestionAnswering,
|
RobertaForQuestionAnswering,
|
||||||
RobertaTokenizer,
|
RobertaTokenizer,
|
||||||
RobertaConfig,
|
|
||||||
XLMConfig,
|
XLMConfig,
|
||||||
XLMForQuestionAnswering,
|
XLMForQuestionAnswering,
|
||||||
XLMTokenizer,
|
XLMTokenizer,
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
XLNetForQuestionAnswering,
|
XLNetForQuestionAnswering,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
DistilBertConfig,
|
get_linear_schedule_with_warmup,
|
||||||
DistilBertForQuestionAnswering,
|
squad_convert_examples_to_features,
|
||||||
DistilBertTokenizer,
|
|
||||||
AlbertConfig,
|
|
||||||
AlbertForQuestionAnswering,
|
|
||||||
AlbertTokenizer,
|
|
||||||
XLMConfig,
|
|
||||||
XLMForQuestionAnswering,
|
|
||||||
XLMTokenizer,
|
|
||||||
)
|
)
|
||||||
|
from transformers.data.metrics.squad_metrics import (
|
||||||
|
compute_predictions_log_probs,
|
||||||
|
compute_predictions_logits,
|
||||||
|
squad_evaluate,
|
||||||
|
)
|
||||||
|
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
except:
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
import tensorflow_datasets
|
import tensorflow_datasets
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
BertConfig,
|
||||||
|
BertForSequenceClassification,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
TFBertForSequenceClassification,
|
TFBertForSequenceClassification,
|
||||||
BertConfig,
|
|
||||||
glue_convert_examples_to_features,
|
glue_convert_examples_to_features,
|
||||||
BertForSequenceClassification,
|
|
||||||
glue_processors,
|
glue_processors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# script parameters
|
# script parameters
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
EVAL_BATCH_SIZE = BATCH_SIZE * 2
|
EVAL_BATCH_SIZE = BATCH_SIZE * 2
|
||||||
|
|||||||
@@ -1,23 +1,33 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
import datetime
|
|
||||||
import os
|
|
||||||
import math
|
|
||||||
import glob
|
|
||||||
import re
|
|
||||||
import tensorflow as tf
|
|
||||||
import collections
|
|
||||||
import numpy as np
|
|
||||||
from seqeval import metrics
|
|
||||||
import _pickle as pickle
|
import _pickle as pickle
|
||||||
from absl import logging
|
import collections
|
||||||
from transformers import TF2_WEIGHTS_NAME, BertConfig, BertTokenizer, TFBertForTokenClassification
|
import datetime
|
||||||
from transformers import RobertaConfig, RobertaTokenizer, TFRobertaForTokenClassification
|
import glob
|
||||||
from transformers import DistilBertConfig, DistilBertTokenizer, TFDistilBertForTokenClassification
|
import math
|
||||||
from transformers import create_optimizer, GradientAccumulator
|
import os
|
||||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
import re
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
from absl import app, flags, logging
|
||||||
|
|
||||||
from fastprogress import master_bar, progress_bar
|
from fastprogress import master_bar, progress_bar
|
||||||
from absl import flags
|
from seqeval import metrics
|
||||||
from absl import app
|
from transformers import (
|
||||||
|
TF2_WEIGHTS_NAME,
|
||||||
|
BertConfig,
|
||||||
|
BertTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
GradientAccumulator,
|
||||||
|
RobertaConfig,
|
||||||
|
RobertaTokenizer,
|
||||||
|
TFBertForTokenClassification,
|
||||||
|
TFDistilBertForTokenClassification,
|
||||||
|
TFRobertaForTokenClassification,
|
||||||
|
create_optimizer,
|
||||||
|
)
|
||||||
|
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||||
|
|
||||||
|
|
||||||
ALL_MODELS = sum(
|
ALL_MODELS = sum(
|
||||||
|
|||||||
@@ -28,34 +28,33 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
AdamW,
|
||||||
|
BertConfig,
|
||||||
|
BertForSequenceClassification,
|
||||||
|
BertTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForSequenceClassification,
|
||||||
|
DistilBertTokenizer,
|
||||||
|
XLMConfig,
|
||||||
|
XLMForSequenceClassification,
|
||||||
|
XLMTokenizer,
|
||||||
|
get_linear_schedule_with_warmup,
|
||||||
|
)
|
||||||
|
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||||
|
from transformers import xnli_compute_metrics as compute_metrics
|
||||||
|
from transformers import xnli_output_modes as output_modes
|
||||||
|
from transformers import xnli_processors as processors
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
except:
|
except:
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
|
||||||
|
|
||||||
from transformers import (
|
|
||||||
WEIGHTS_NAME,
|
|
||||||
BertConfig,
|
|
||||||
BertForSequenceClassification,
|
|
||||||
BertTokenizer,
|
|
||||||
XLMConfig,
|
|
||||||
XLMForSequenceClassification,
|
|
||||||
XLMTokenizer,
|
|
||||||
DistilBertConfig,
|
|
||||||
DistilBertForSequenceClassification,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
|
||||||
|
|
||||||
from transformers import xnli_compute_metrics as compute_metrics
|
|
||||||
from transformers import xnli_output_modes as output_modes
|
|
||||||
from transformers import xnli_processors as processors
|
|
||||||
|
|
||||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -20,13 +20,13 @@ the model within the original codebase to be able to only save its `state_dict`.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from collections import namedtuple
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from models.model_builder import AbsSummarizer # The authors' implementation
|
|
||||||
from model_bertabs import BertAbsSummarizer
|
from model_bertabs import BertAbsSummarizer
|
||||||
|
from models.model_builder import AbsSummarizer # The authors' implementation
|
||||||
from transformers import BertTokenizer
|
from transformers import BertTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -27,9 +27,8 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.init import xavier_uniform_
|
from torch.nn.init import xavier_uniform_
|
||||||
|
|
||||||
from transformers import BertModel, BertConfig, PreTrainedModel
|
|
||||||
|
|
||||||
from configuration_bertabs import BertAbsConfig
|
from configuration_bertabs import BertAbsConfig
|
||||||
|
from transformers import BertConfig, BertModel, PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
MAX_SIZE = 5000
|
MAX_SIZE = 5000
|
||||||
|
|||||||
@@ -1,26 +1,25 @@
|
|||||||
#! /usr/bin/python3
|
#! /usr/bin/python3
|
||||||
import argparse
|
import argparse
|
||||||
from collections import namedtuple
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, SequentialSampler
|
from torch.utils.data import DataLoader, SequentialSampler
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers import BertTokenizer
|
|
||||||
|
|
||||||
from modeling_bertabs import BertAbs, build_predictor
|
from modeling_bertabs import BertAbs, build_predictor
|
||||||
|
from transformers import BertTokenizer
|
||||||
from utils_summarization import (
|
from utils_summarization import (
|
||||||
SummarizationDataset,
|
SummarizationDataset,
|
||||||
encode_for_summarization,
|
|
||||||
build_mask,
|
build_mask,
|
||||||
fit_to_block_size,
|
|
||||||
compute_token_type_ids,
|
compute_token_type_ids,
|
||||||
|
encode_for_summarization,
|
||||||
|
fit_to_block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from collections import deque
|
|
||||||
import os
|
import os
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|||||||
@@ -17,12 +17,7 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from utils_summarization import (
|
from utils_summarization import build_mask, compute_token_type_ids, fit_to_block_size, process_story
|
||||||
compute_token_type_ids,
|
|
||||||
fit_to_block_size,
|
|
||||||
build_mask,
|
|
||||||
process_story,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SummarizationDataProcessingTest(unittest.TestCase):
|
class SummarizationDataProcessingTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -12,14 +12,17 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import unittest
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import run_generation
|
||||||
|
import run_glue
|
||||||
|
import run_squad
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# python 3.4+ can use builtin unittest.mock instead of mock package
|
# python 3.4+ can use builtin unittest.mock instead of mock package
|
||||||
@@ -27,9 +30,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from mock import patch
|
from mock import patch
|
||||||
|
|
||||||
import run_glue
|
|
||||||
import run_squad
|
|
||||||
import run_generation
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
|||||||
@@ -17,16 +17,17 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from io import open
|
from io import open
|
||||||
import json
|
|
||||||
import csv
|
|
||||||
import glob
|
|
||||||
import tqdm
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer,
|
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelWithLMHead,
|
|
||||||
AutoModelForSequenceClassification,
|
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
|
AutoModelForSequenceClassification,
|
||||||
|
AutoModelWithLMHead,
|
||||||
|
AutoTokenizer,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import add_start_docstrings
|
from transformers.file_utils import add_start_docstrings
|
||||||
|
|
||||||
|
|
||||||
dependencies = ["torch", "tqdm", "boto3", "requests", "regex", "sentencepiece", "sacremoses"]
|
dependencies = ["torch", "tqdm", "boto3", "requests", "regex", "sentencepiece", "sacremoses"]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
1
setup.py
1
setup.py
@@ -34,6 +34,7 @@ To create the package for pypi.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,54 +17,55 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import glob
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
try:
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
except:
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
|
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
|
AdamW,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
BertForQuestionAnswering,
|
BertForQuestionAnswering,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
|
DistilBertConfig,
|
||||||
|
DistilBertForQuestionAnswering,
|
||||||
|
DistilBertTokenizer,
|
||||||
XLMConfig,
|
XLMConfig,
|
||||||
XLMForQuestionAnswering,
|
XLMForQuestionAnswering,
|
||||||
XLMTokenizer,
|
XLMTokenizer,
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
XLNetForQuestionAnswering,
|
XLNetForQuestionAnswering,
|
||||||
XLNetTokenizer,
|
XLNetTokenizer,
|
||||||
DistilBertConfig,
|
get_linear_schedule_with_warmup,
|
||||||
DistilBertForQuestionAnswering,
|
|
||||||
DistilBertTokenizer,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from transformers import AdamW, get_linear_schedule_with_warmup
|
|
||||||
|
|
||||||
from utils_squad import (
|
from utils_squad import (
|
||||||
read_squad_examples,
|
|
||||||
convert_examples_to_features,
|
|
||||||
RawResult,
|
RawResult,
|
||||||
write_predictions,
|
|
||||||
RawResultExtended,
|
RawResultExtended,
|
||||||
|
convert_examples_to_features,
|
||||||
|
read_squad_examples,
|
||||||
|
write_predictions,
|
||||||
write_predictions_extended,
|
write_predictions_extended,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||||
# You can remove it from the dependencies if you are using this script outside of the library
|
# You can remove it from the dependencies if you are using this script outside of the library
|
||||||
# We've added it here for automated tests (see examples/test_examples.py file)
|
# We've added it here for automated tests (see examples/test_examples.py file)
|
||||||
from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
|
from utils_squad_evaluate import EVAL_OPTS
|
||||||
|
from utils_squad_evaluate import main as evaluate_on_squad
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
except:
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -16,16 +16,17 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import collections
|
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||||
|
|
||||||
# Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
|
# Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
|
||||||
from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
|
from utils_squad_evaluate import find_all_best_thresh_v2, get_raw_scores, make_qid_to_has_ans
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import six
|
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
XXX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XXX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -14,16 +14,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Convert XXX checkpoint."""
|
"""Convert XXX checkpoint."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
|
from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|||||||
@@ -21,21 +21,22 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import copy
|
|
||||||
import itertools
|
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_xxx import XxxConfig
|
from .configuration_xxx import XxxConfig
|
||||||
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -20,22 +20,23 @@
|
|||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import copy
|
|
||||||
import itertools
|
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
|
||||||
from .configuration_xxx import XxxConfig
|
from .configuration_xxx import XxxConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -12,19 +12,18 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
import sys
|
import sys
|
||||||
|
import unittest
|
||||||
from .modeling_tf_common_test import TFCommonTestCases, ids_tensor
|
|
||||||
from .configuration_common_test import ConfigTester
|
|
||||||
from .utils import CACHE_DIR, require_tf, slow
|
|
||||||
|
|
||||||
from transformers import XxxConfig, is_tf_available
|
from transformers import XxxConfig, is_tf_available
|
||||||
|
|
||||||
|
from .configuration_common_test import ConfigTester
|
||||||
|
from .modeling_tf_common_test import TFCommonTestCases, ids_tensor
|
||||||
|
from .utils import CACHE_DIR, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from transformers.modeling_tf_xxx import (
|
from transformers.modeling_tf_xxx import (
|
||||||
|
|||||||
@@ -12,18 +12,17 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .modeling_common_test import CommonTestCases, ids_tensor
|
|
||||||
from .configuration_common_test import ConfigTester
|
from .configuration_common_test import ConfigTester
|
||||||
|
from .modeling_common_test import CommonTestCases, ids_tensor
|
||||||
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
from .utils import CACHE_DIR, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers import (
|
from transformers import (
|
||||||
XxxConfig,
|
XxxConfig,
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from transformers.tokenization_bert import XxxTokenizer, VOCAB_FILES_NAMES
|
from transformers.tokenization_bert import VOCAB_FILES_NAMES, XxxTokenizer
|
||||||
|
|
||||||
from .tokenization_tests_commons import CommonTestCases
|
from .tokenization_tests_commons import CommonTestCases
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from io import open
|
|||||||
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
####################################################
|
####################################################
|
||||||
|
|||||||
@@ -15,86 +15,114 @@ except:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||||
|
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
||||||
|
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||||
|
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||||
|
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||||
|
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
||||||
|
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||||
|
from .configuration_mmbt import MMBTConfig
|
||||||
|
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||||
|
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||||
|
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
||||||
|
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
|
||||||
|
|
||||||
|
# Configurations
|
||||||
|
from .configuration_utils import PretrainedConfig
|
||||||
|
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
|
||||||
|
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
||||||
|
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
|
||||||
|
from .data import (
|
||||||
|
DataProcessor,
|
||||||
|
InputExample,
|
||||||
|
InputFeatures,
|
||||||
|
SingleSentenceClassificationProcessor,
|
||||||
|
SquadExample,
|
||||||
|
SquadFeatures,
|
||||||
|
SquadV1Processor,
|
||||||
|
SquadV2Processor,
|
||||||
|
glue_convert_examples_to_features,
|
||||||
|
glue_output_modes,
|
||||||
|
glue_processors,
|
||||||
|
glue_tasks_num_labels,
|
||||||
|
is_sklearn_available,
|
||||||
|
squad_convert_examples_to_features,
|
||||||
|
xnli_output_modes,
|
||||||
|
xnli_processors,
|
||||||
|
xnli_tasks_num_labels,
|
||||||
|
)
|
||||||
|
|
||||||
# Files and general utilities
|
# Files and general utilities
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
TRANSFORMERS_CACHE,
|
|
||||||
PYTORCH_TRANSFORMERS_CACHE,
|
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE,
|
|
||||||
cached_path,
|
|
||||||
add_start_docstrings,
|
|
||||||
add_end_docstrings,
|
|
||||||
WEIGHTS_NAME,
|
|
||||||
TF2_WEIGHTS_NAME,
|
|
||||||
TF_WEIGHTS_NAME,
|
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
MODEL_CARD_NAME,
|
MODEL_CARD_NAME,
|
||||||
|
PYTORCH_PRETRAINED_BERT_CACHE,
|
||||||
|
PYTORCH_TRANSFORMERS_CACHE,
|
||||||
|
TF2_WEIGHTS_NAME,
|
||||||
|
TF_WEIGHTS_NAME,
|
||||||
|
TRANSFORMERS_CACHE,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
add_end_docstrings,
|
||||||
|
add_start_docstrings,
|
||||||
|
cached_path,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .data import (
|
# Model Cards
|
||||||
is_sklearn_available,
|
from .modelcard import ModelCard
|
||||||
InputExample,
|
|
||||||
InputFeatures,
|
# TF 2.0 <=> PyTorch conversion utilities
|
||||||
DataProcessor,
|
from .modeling_tf_pytorch_utils import (
|
||||||
SingleSentenceClassificationProcessor,
|
convert_tf_weight_name_to_pt_weight_name,
|
||||||
glue_output_modes,
|
load_pytorch_checkpoint_in_tf2_model,
|
||||||
glue_convert_examples_to_features,
|
load_pytorch_model_in_tf2_model,
|
||||||
glue_processors,
|
load_pytorch_weights_in_tf2_model,
|
||||||
glue_tasks_num_labels,
|
load_tf2_checkpoint_in_pytorch_model,
|
||||||
xnli_output_modes,
|
load_tf2_model_in_pytorch_model,
|
||||||
xnli_processors,
|
load_tf2_weights_in_pytorch_model,
|
||||||
xnli_tasks_num_labels,
|
|
||||||
squad_convert_examples_to_features,
|
|
||||||
SquadFeatures,
|
|
||||||
SquadExample,
|
|
||||||
SquadV1Processor,
|
|
||||||
SquadV2Processor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Pipelines
|
||||||
|
from .pipelines import (
|
||||||
|
CsvPipelineDataFormat,
|
||||||
|
FeatureExtractionPipeline,
|
||||||
|
JsonPipelineDataFormat,
|
||||||
|
NerPipeline,
|
||||||
|
PipedPipelineDataFormat,
|
||||||
|
Pipeline,
|
||||||
|
PipelineDataFormat,
|
||||||
|
QuestionAnsweringPipeline,
|
||||||
|
TextClassificationPipeline,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
|
from .tokenization_albert import AlbertTokenizer
|
||||||
|
from .tokenization_auto import AutoTokenizer
|
||||||
|
from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
|
||||||
|
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
||||||
|
from .tokenization_camembert import CamembertTokenizer
|
||||||
|
from .tokenization_ctrl import CTRLTokenizer
|
||||||
|
from .tokenization_distilbert import DistilBertTokenizer
|
||||||
|
from .tokenization_gpt2 import GPT2Tokenizer
|
||||||
|
from .tokenization_openai import OpenAIGPTTokenizer
|
||||||
|
from .tokenization_roberta import RobertaTokenizer
|
||||||
|
from .tokenization_t5 import T5Tokenizer
|
||||||
|
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer
|
||||||
|
|
||||||
|
# Tokenizers
|
||||||
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
|
from .tokenization_xlm import XLMTokenizer
|
||||||
|
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||||
|
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
if is_sklearn_available():
|
if is_sklearn_available():
|
||||||
from .data import glue_compute_metrics, xnli_compute_metrics
|
from .data import glue_compute_metrics, xnli_compute_metrics
|
||||||
|
|
||||||
# Model Cards
|
|
||||||
from .modelcard import ModelCard
|
|
||||||
|
|
||||||
# Tokenizers
|
|
||||||
from .tokenization_utils import PreTrainedTokenizer
|
|
||||||
from .tokenization_auto import AutoTokenizer
|
|
||||||
from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
|
||||||
from .tokenization_bert_japanese import BertJapaneseTokenizer, MecabTokenizer, CharacterTokenizer
|
|
||||||
from .tokenization_openai import OpenAIGPTTokenizer
|
|
||||||
from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLCorpus
|
|
||||||
from .tokenization_gpt2 import GPT2Tokenizer
|
|
||||||
from .tokenization_ctrl import CTRLTokenizer
|
|
||||||
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
|
|
||||||
from .tokenization_xlm import XLMTokenizer
|
|
||||||
from .tokenization_roberta import RobertaTokenizer
|
|
||||||
from .tokenization_distilbert import DistilBertTokenizer
|
|
||||||
from .tokenization_albert import AlbertTokenizer
|
|
||||||
from .tokenization_camembert import CamembertTokenizer
|
|
||||||
from .tokenization_t5 import T5Tokenizer
|
|
||||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
|
||||||
|
|
||||||
# Configurations
|
|
||||||
from .configuration_utils import PretrainedConfig
|
|
||||||
from .configuration_auto import AutoConfig, ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_xlnet import XLNetConfig, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_t5 import T5Config, T5_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_xlm_roberta import XLMRobertaConfig, XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
|
||||||
from .configuration_mmbt import MMBTConfig
|
|
||||||
|
|
||||||
# Modeling
|
# Modeling
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -345,30 +373,6 @@ if is_tf_available():
|
|||||||
# Optimization
|
# Optimization
|
||||||
from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
|
from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
|
||||||
|
|
||||||
# TF 2.0 <=> PyTorch conversion utilities
|
|
||||||
from .modeling_tf_pytorch_utils import (
|
|
||||||
convert_tf_weight_name_to_pt_weight_name,
|
|
||||||
load_pytorch_checkpoint_in_tf2_model,
|
|
||||||
load_pytorch_weights_in_tf2_model,
|
|
||||||
load_pytorch_model_in_tf2_model,
|
|
||||||
load_tf2_checkpoint_in_pytorch_model,
|
|
||||||
load_tf2_weights_in_pytorch_model,
|
|
||||||
load_tf2_model_in_pytorch_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pipelines
|
|
||||||
from .pipelines import (
|
|
||||||
pipeline,
|
|
||||||
PipelineDataFormat,
|
|
||||||
CsvPipelineDataFormat,
|
|
||||||
JsonPipelineDataFormat,
|
|
||||||
PipedPipelineDataFormat,
|
|
||||||
Pipeline,
|
|
||||||
FeatureExtractionPipeline,
|
|
||||||
QuestionAnsweringPipeline,
|
|
||||||
NerPipeline,
|
|
||||||
TextClassificationPipeline,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_tf_available() and not is_torch_available():
|
if not is_tf_available() and not is_torch_available():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
from transformers import AutoModel, AutoTokenizer
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import logging
|
|||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
from transformers.pipelines import pipeline, Pipeline, PipelineDataFormat, SUPPORTED_TASKS
|
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
from argparse import ArgumentParser, Namespace
|
|
||||||
from typing import List, Optional, Union, Any
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from argparse import ArgumentParser, Namespace
|
||||||
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
|
from transformers import Pipeline
|
||||||
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
|
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from uvicorn import run
|
from uvicorn import run
|
||||||
@@ -14,9 +18,6 @@ except (ImportError, AttributeError):
|
|||||||
Body = lambda *x, **y: None
|
Body = lambda *x, **y: None
|
||||||
_serve_dependancies_installed = False
|
_serve_dependancies_installed = False
|
||||||
|
|
||||||
from transformers import Pipeline
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
|
||||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
|
||||||
|
|
||||||
logger = logging.getLogger("transformers-cli/serving")
|
logger = logging.getLogger("transformers-cli/serving")
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,10 @@ import os
|
|||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
|
from transformers import SingleSentenceClassificationProcessor as Processor
|
||||||
|
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
from transformers import (
|
|
||||||
is_tf_available,
|
|
||||||
is_torch_available,
|
|
||||||
TextClassificationPipeline,
|
|
||||||
SingleSentenceClassificationProcessor as Processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_tf_available() and not is_torch_available():
|
if not is_tf_available() and not is_torch_available():
|
||||||
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
import os
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from getpass import getpass
|
from getpass import getpass
|
||||||
import os
|
|
||||||
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from transformers.commands import BaseTransformersCLICommand
|
||||||
from transformers.hf_api import HfApi, HfFolder, HTTPError
|
from transformers.hf_api import HfApi, HfFolder, HTTPError
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
|
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
|
||||||
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
|
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
|
||||||
|
|||||||
@@ -18,19 +18,20 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
||||||
from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
||||||
from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
||||||
from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
||||||
from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
||||||
from .configuration_xlnet import XLNetConfig, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||||
from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
||||||
from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
||||||
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
||||||
from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
|
||||||
from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
|
||||||
from .configuration_t5 import T5Config, T5_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
||||||
from .configuration_xlm_roberta import XLMRobertaConfig, XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from io import open
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import logging
|
|||||||
|
|
||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from io import open
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}
|
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}
|
||||||
|
|||||||
@@ -15,13 +15,14 @@
|
|||||||
""" DistilBERT model configuration """
|
""" DistilBERT model configuration """
|
||||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
import sys
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from io import open
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from io import open
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import logging
|
|||||||
|
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import six
|
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from io import open
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from .file_utils import CONFIG_NAME, cached_path, is_remote_url, hf_bucket_url
|
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from io import open
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import logging
|
|||||||
|
|
||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from io import open
|
|||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
|
|||||||
@@ -14,16 +14,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Convert ALBERT checkpoint."""
|
"""Convert ALBERT checkpoint."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
|
from transformers import AlbertConfig, AlbertForMaskedLM, load_tf_weights_in_albert
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|||||||
@@ -14,16 +14,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Convert BERT checkpoint."""
|
"""Convert BERT checkpoint."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,13 @@
|
|||||||
|
|
||||||
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
|
"""Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint."""
|
||||||
|
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
import torch
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
import torch
|
||||||
|
|
||||||
from transformers import BertModel
|
from transformers import BertModel
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,13 +17,13 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|||||||
@@ -17,13 +17,13 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|||||||
@@ -14,58 +14,59 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Convert pytorch checkpoints to TensorFlow """
|
""" Convert pytorch checkpoints to TensorFlow """
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers import is_torch_available, cached_path
|
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
load_pytorch_checkpoint_in_tf2_model,
|
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
AlbertConfig,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
|
CTRLConfig,
|
||||||
|
DistilBertConfig,
|
||||||
|
GPT2Config,
|
||||||
|
OpenAIGPTConfig,
|
||||||
|
RobertaConfig,
|
||||||
|
T5Config,
|
||||||
|
TFAlbertForMaskedLM,
|
||||||
TFBertForPreTraining,
|
TFBertForPreTraining,
|
||||||
TFBertForQuestionAnswering,
|
TFBertForQuestionAnswering,
|
||||||
TFBertForSequenceClassification,
|
TFBertForSequenceClassification,
|
||||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TFCTRLLMHeadModel,
|
||||||
GPT2Config,
|
|
||||||
TFGPT2LMHeadModel,
|
|
||||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
||||||
XLNetConfig,
|
|
||||||
TFXLNetLMHeadModel,
|
|
||||||
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
||||||
XLMConfig,
|
|
||||||
TFXLMWithLMHeadModel,
|
|
||||||
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
||||||
TransfoXLConfig,
|
|
||||||
TFTransfoXLLMHeadModel,
|
|
||||||
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
||||||
OpenAIGPTConfig,
|
|
||||||
TFOpenAIGPTLMHeadModel,
|
|
||||||
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
||||||
RobertaConfig,
|
|
||||||
TFRobertaForMaskedLM,
|
|
||||||
TFRobertaForSequenceClassification,
|
|
||||||
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
||||||
DistilBertConfig,
|
|
||||||
TFDistilBertForMaskedLM,
|
TFDistilBertForMaskedLM,
|
||||||
TFDistilBertForQuestionAnswering,
|
TFDistilBertForQuestionAnswering,
|
||||||
TFDistilBertForSequenceClassification,
|
TFDistilBertForSequenceClassification,
|
||||||
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TFGPT2LMHeadModel,
|
||||||
CTRLConfig,
|
TFOpenAIGPTLMHeadModel,
|
||||||
TFCTRLLMHeadModel,
|
TFRobertaForMaskedLM,
|
||||||
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TFRobertaForSequenceClassification,
|
||||||
AlbertConfig,
|
|
||||||
TFAlbertForMaskedLM,
|
|
||||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
|
||||||
T5Config,
|
|
||||||
TFT5WithLMHeadModel,
|
TFT5WithLMHeadModel,
|
||||||
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TFTransfoXLLMHeadModel,
|
||||||
|
TFXLMWithLMHeadModel,
|
||||||
|
TFXLNetLMHeadModel,
|
||||||
|
TransfoXLConfig,
|
||||||
|
XLMConfig,
|
||||||
|
XLNetConfig,
|
||||||
|
cached_path,
|
||||||
|
is_torch_available,
|
||||||
|
load_pytorch_checkpoint_in_tf2_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -158,8 +159,6 @@ else:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
|
|||||||
@@ -18,16 +18,13 @@ from __future__ import absolute_import, division, print_function
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
import fairseq
|
import numpy as np
|
||||||
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
import fairseq
|
||||||
raise Exception("requires fairseq >= 0.9.0")
|
|
||||||
|
|
||||||
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
||||||
from fairseq.modules import TransformerSentenceEncoderLayer
|
from fairseq.modules import TransformerSentenceEncoderLayer
|
||||||
from transformers.modeling_bert import (
|
from transformers.modeling_bert import (
|
||||||
@@ -47,6 +44,11 @@ from transformers.modeling_roberta import (
|
|||||||
RobertaModel,
|
RobertaModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||||
|
raise Exception("requires fairseq >= 0.9.0")
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -14,16 +14,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Convert T5 checkpoint."""
|
"""Convert T5 checkpoint."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import T5Config, T5Model, load_tf_weights_in_t5
|
from transformers import T5Config, T5Model, load_tf_weights_in_t5
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from io import open
|
from io import open
|
||||||
@@ -24,17 +25,21 @@ from io import open
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import transformers.tokenization_transfo_xl as data_utils
|
import transformers.tokenization_transfo_xl as data_utils
|
||||||
|
from transformers import (
|
||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME
|
CONFIG_NAME,
|
||||||
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
|
WEIGHTS_NAME,
|
||||||
|
TransfoXLConfig,
|
||||||
|
TransfoXLLMHeadModel,
|
||||||
|
load_tf_weights_in_transfo_xl,
|
||||||
|
)
|
||||||
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||||
|
|
||||||
|
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
import cPickle as pickle
|
import cPickle as pickle
|
||||||
else:
|
else:
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|||||||
@@ -18,15 +18,15 @@ from __future__ import absolute_import, division, print_function
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy
|
import numpy
|
||||||
|
import torch
|
||||||
|
|
||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME
|
from transformers import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
from transformers.tokenization_xlm import VOCAB_FILES_NAMES
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|||||||
@@ -14,24 +14,25 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Convert BERT checkpoint."""
|
"""Convert BERT checkpoint."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
WEIGHTS_NAME,
|
WEIGHTS_NAME,
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
XLNetLMHeadModel,
|
|
||||||
XLNetForQuestionAnswering,
|
XLNetForQuestionAnswering,
|
||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
|
XLNetLMHeadModel,
|
||||||
load_tf_weights_in_xlnet,
|
load_tf_weights_in_xlnet,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
GLUE_TASKS_NUM_LABELS = {
|
GLUE_TASKS_NUM_LABELS = {
|
||||||
"cola": 2,
|
"cola": 2,
|
||||||
"mnli": 3,
|
"mnli": 3,
|
||||||
@@ -44,7 +45,6 @@ GLUE_TASKS_NUM_LABELS = {
|
|||||||
"wnli": 2,
|
"wnli": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,23 @@
|
|||||||
|
from .metrics import is_sklearn_available
|
||||||
from .processors import (
|
from .processors import (
|
||||||
|
DataProcessor,
|
||||||
InputExample,
|
InputExample,
|
||||||
InputFeatures,
|
InputFeatures,
|
||||||
DataProcessor,
|
|
||||||
SquadFeatures,
|
|
||||||
SingleSentenceClassificationProcessor,
|
SingleSentenceClassificationProcessor,
|
||||||
|
SquadExample,
|
||||||
|
SquadFeatures,
|
||||||
|
SquadV1Processor,
|
||||||
|
SquadV2Processor,
|
||||||
|
glue_convert_examples_to_features,
|
||||||
|
glue_output_modes,
|
||||||
|
glue_processors,
|
||||||
|
glue_tasks_num_labels,
|
||||||
|
squad_convert_examples_to_features,
|
||||||
|
xnli_output_modes,
|
||||||
|
xnli_processors,
|
||||||
|
xnli_tasks_num_labels,
|
||||||
)
|
)
|
||||||
from .processors import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
|
||||||
from .processors import squad_convert_examples_to_features, SquadExample, SquadV1Processor, SquadV2Processor
|
|
||||||
from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
|
||||||
|
|
||||||
from .metrics import is_sklearn_available
|
|
||||||
|
|
||||||
if is_sklearn_available():
|
if is_sklearn_available():
|
||||||
from .metrics import glue_compute_metrics, xnli_compute_metrics
|
from .metrics import glue_compute_metrics, xnli_compute_metrics
|
||||||
|
|||||||
@@ -15,8 +15,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import sys
|
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -8,17 +8,19 @@ that a question is unanswerable.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import collections
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import collections
|
|
||||||
from io import open
|
|
||||||
from tqdm import tqdm
|
|
||||||
import string
|
|
||||||
import re
|
import re
|
||||||
|
import string
|
||||||
|
from io import open
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .utils import InputExample, InputFeatures, DataProcessor, SingleSentenceClassificationProcessor
|
from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
|
||||||
from .glue import glue_output_modes, glue_processors, glue_tasks_num_labels, glue_convert_examples_to_features
|
from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
|
||||||
from .squad import squad_convert_examples_to_features, SquadFeatures, SquadExample, SquadV1Processor, SquadV2Processor
|
from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
|
||||||
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
||||||
|
|||||||
@@ -18,8 +18,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from .utils import DataProcessor, InputExample, InputFeatures
|
|
||||||
from ...file_utils import is_tf_available
|
from ...file_utils import is_tf_available
|
||||||
|
from .utils import DataProcessor, InputExample, InputFeatures
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
from tqdm import tqdm
|
|
||||||
import collections
|
import collections
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import json
|
|
||||||
import numpy as np
|
|
||||||
from multiprocessing import Pool
|
|
||||||
from multiprocessing import cpu_count
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from multiprocessing import Pool, cpu_count
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ...file_utils import is_tf_available, is_torch_available
|
||||||
from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
|
from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||||
from .utils import DataProcessor, InputExample, InputFeatures
|
from .utils import DataProcessor, InputExample, InputFeatures
|
||||||
from ...file_utils import is_tf_available, is_torch_available
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -14,14 +14,15 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import csv
|
|
||||||
import sys
|
|
||||||
import copy
|
import copy
|
||||||
|
import csv
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
from ...file_utils import is_tf_available, is_torch_available
|
from ...file_utils import is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import os
|
|||||||
|
|
||||||
from .utils import DataProcessor, InputExample
|
from .utils import DataProcessor, InputExample
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,26 +5,27 @@ Copyright by the AllenNLP authors.
|
|||||||
"""
|
"""
|
||||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
import sys
|
import fnmatch
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import six
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import fnmatch
|
from contextlib import contextmanager
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
import requests
|
||||||
|
import six
|
||||||
from botocore.config import Config
|
from botocore.config import Config
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
import requests
|
from filelock import FileLock
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from contextlib import contextmanager
|
|
||||||
from . import __version__
|
from . import __version__
|
||||||
|
|
||||||
from filelock import FileLock
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import six
|
|||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
ENDPOINT = "https://huggingface.co"
|
ENDPOINT = "https://huggingface.co"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,15 +23,14 @@ import os
|
|||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
MODEL_CARD_NAME,
|
MODEL_CARD_NAME,
|
||||||
WEIGHTS_NAME,
|
|
||||||
TF2_WEIGHTS_NAME,
|
TF2_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
cached_path,
|
cached_path,
|
||||||
is_remote_url,
|
|
||||||
hf_bucket_url,
|
hf_bucket_url,
|
||||||
|
is_remote_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -14,17 +14,21 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch ALBERT model. """
|
"""PyTorch ALBERT model. """
|
||||||
|
|
||||||
import os
|
|
||||||
import math
|
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
from transformers.configuration_albert import AlbertConfig
|
from transformers.configuration_albert import AlbertConfig
|
||||||
from transformers.modeling_bert import BertEmbeddings, BertSelfAttention, prune_linear_layer, ACT2FN
|
from transformers.modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
|
||||||
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -29,80 +29,78 @@ from .configuration_auto import (
|
|||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
TransfoXLConfig,
|
TransfoXLConfig,
|
||||||
XLMConfig,
|
XLMConfig,
|
||||||
XLNetConfig,
|
|
||||||
XLMRobertaConfig,
|
XLMRobertaConfig,
|
||||||
|
XLNetConfig,
|
||||||
|
)
|
||||||
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_albert import (
|
||||||
|
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
AlbertForMaskedLM,
|
||||||
|
AlbertForQuestionAnswering,
|
||||||
|
AlbertForSequenceClassification,
|
||||||
|
AlbertModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .modeling_bert import (
|
from .modeling_bert import (
|
||||||
BertModel,
|
|
||||||
BertForMaskedLM,
|
|
||||||
BertForSequenceClassification,
|
|
||||||
BertForQuestionAnswering,
|
|
||||||
BertForTokenClassification,
|
|
||||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
BertForMaskedLM,
|
||||||
|
BertForQuestionAnswering,
|
||||||
|
BertForSequenceClassification,
|
||||||
|
BertForTokenClassification,
|
||||||
|
BertModel,
|
||||||
)
|
)
|
||||||
from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
|
from .modeling_camembert import (
|
||||||
from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
|
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
from .modeling_ctrl import CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
CamembertForMaskedLM,
|
||||||
from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
|
CamembertForMultipleChoice,
|
||||||
from .modeling_xlnet import (
|
CamembertForSequenceClassification,
|
||||||
XLNetModel,
|
CamembertForTokenClassification,
|
||||||
XLNetLMHeadModel,
|
CamembertModel,
|
||||||
XLNetForSequenceClassification,
|
|
||||||
XLNetForQuestionAnswering,
|
|
||||||
XLNetForTokenClassification,
|
|
||||||
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
)
|
)
|
||||||
from .modeling_xlm import (
|
from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel, CTRLModel
|
||||||
XLMModel,
|
from .modeling_distilbert import (
|
||||||
XLMWithLMHeadModel,
|
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
XLMForSequenceClassification,
|
DistilBertForMaskedLM,
|
||||||
XLMForQuestionAnswering,
|
DistilBertForQuestionAnswering,
|
||||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
DistilBertForSequenceClassification,
|
||||||
|
DistilBertForTokenClassification,
|
||||||
|
DistilBertModel,
|
||||||
)
|
)
|
||||||
|
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
|
||||||
|
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
|
||||||
from .modeling_roberta import (
|
from .modeling_roberta import (
|
||||||
RobertaModel,
|
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
RobertaForSequenceClassification,
|
RobertaForSequenceClassification,
|
||||||
RobertaForTokenClassification,
|
RobertaForTokenClassification,
|
||||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
RobertaModel,
|
||||||
)
|
)
|
||||||
from .modeling_distilbert import (
|
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5Model, T5WithLMHeadModel
|
||||||
DistilBertModel,
|
from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel
|
||||||
DistilBertForQuestionAnswering,
|
|
||||||
DistilBertForMaskedLM,
|
|
||||||
DistilBertForSequenceClassification,
|
|
||||||
DistilBertForTokenClassification,
|
|
||||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
)
|
|
||||||
from .modeling_camembert import (
|
|
||||||
CamembertModel,
|
|
||||||
CamembertForMaskedLM,
|
|
||||||
CamembertForSequenceClassification,
|
|
||||||
CamembertForMultipleChoice,
|
|
||||||
CamembertForTokenClassification,
|
|
||||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
)
|
|
||||||
from .modeling_albert import (
|
|
||||||
AlbertModel,
|
|
||||||
AlbertForMaskedLM,
|
|
||||||
AlbertForSequenceClassification,
|
|
||||||
AlbertForQuestionAnswering,
|
|
||||||
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
)
|
|
||||||
from .modeling_t5 import T5Model, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
|
||||||
from .modeling_xlm_roberta import (
|
|
||||||
XLMRobertaModel,
|
|
||||||
XLMRobertaForMaskedLM,
|
|
||||||
XLMRobertaForSequenceClassification,
|
|
||||||
XLMRobertaForMultipleChoice,
|
|
||||||
XLMRobertaForTokenClassification,
|
|
||||||
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
|
from .modeling_xlm import (
|
||||||
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
XLMForQuestionAnswering,
|
||||||
|
XLMForSequenceClassification,
|
||||||
|
XLMModel,
|
||||||
|
XLMWithLMHeadModel,
|
||||||
|
)
|
||||||
|
from .modeling_xlm_roberta import (
|
||||||
|
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
XLMRobertaForMaskedLM,
|
||||||
|
XLMRobertaForMultipleChoice,
|
||||||
|
XLMRobertaForSequenceClassification,
|
||||||
|
XLMRobertaForTokenClassification,
|
||||||
|
XLMRobertaModel,
|
||||||
|
)
|
||||||
|
from .modeling_xlnet import (
|
||||||
|
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
|
XLNetForQuestionAnswering,
|
||||||
|
XLNetForSequenceClassification,
|
||||||
|
XLNetForTokenClassification,
|
||||||
|
XLNetLMHeadModel,
|
||||||
|
XLNetModel,
|
||||||
|
)
|
||||||
|
|
||||||
from .file_utils import add_start_docstrings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -26,9 +26,10 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
|
||||||
from .configuration_bert import BertConfig
|
from .configuration_bert import BertConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -19,15 +19,16 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .modeling_roberta import (
|
|
||||||
RobertaModel,
|
|
||||||
RobertaForMaskedLM,
|
|
||||||
RobertaForSequenceClassification,
|
|
||||||
RobertaForMultipleChoice,
|
|
||||||
RobertaForTokenClassification,
|
|
||||||
)
|
|
||||||
from .configuration_camembert import CamembertConfig
|
from .configuration_camembert import CamembertConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_roberta import (
|
||||||
|
RobertaForMaskedLM,
|
||||||
|
RobertaForMultipleChoice,
|
||||||
|
RobertaForSequenceClassification,
|
||||||
|
RobertaForTokenClassification,
|
||||||
|
RobertaModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -24,15 +24,17 @@ import math
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
|
|
||||||
from .configuration_ctrl import CTRLConfig
|
from .configuration_ctrl import CTRLConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -18,25 +18,23 @@
|
|||||||
"""
|
"""
|
||||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import copy
|
|
||||||
import sys
|
import sys
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import itertools
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
|
||||||
from .configuration_distilbert import DistilBertConfig
|
from .configuration_distilbert import DistilBertConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
|
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from tqdm import trange
|
|||||||
|
|
||||||
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user