directly load from TF checkpoints + code cleanup
This commit is contained in:
@@ -2,6 +2,7 @@ __version__ = "0.5.0"
|
|||||||
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||||
from .tokenization_openai import OpenAIGPTTokenizer
|
from .tokenization_openai import OpenAIGPTTokenizer
|
||||||
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
|
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
|
||||||
|
|
||||||
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
||||||
BertForMaskedLM, BertForNextSentencePrediction,
|
BertForMaskedLM, BertForNextSentencePrediction,
|
||||||
BertForSequenceClassification, BertForMultipleChoice,
|
BertForSequenceClassification, BertForMultipleChoice,
|
||||||
@@ -9,6 +10,11 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
|||||||
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
|
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
|
||||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
|
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
|
||||||
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel)
|
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel)
|
||||||
|
|
||||||
from .optimization import BertAdam
|
from .optimization import BertAdam
|
||||||
from .optimization_openai import OpenAIAdam
|
from .optimization_openai import OpenAIAdam
|
||||||
|
|
||||||
|
from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt
|
||||||
|
from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert
|
||||||
|
from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl
|
||||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
|
|||||||
@@ -26,9 +26,29 @@ import numpy as np
|
|||||||
|
|
||||||
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
||||||
|
|
||||||
|
|
||||||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
||||||
# Load weights from TF model
|
# Construct model
|
||||||
|
if openai_config_file == "":
|
||||||
|
config = OpenAIGPTConfig()
|
||||||
|
else:
|
||||||
|
config = OpenAIGPTConfig(openai_config_file)
|
||||||
|
model = OpenAIGPTModel(config)
|
||||||
|
|
||||||
|
# Load weights from numpy
|
||||||
|
load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)
|
||||||
|
|
||||||
|
# Save pytorch-model
|
||||||
|
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||||
|
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||||
|
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||||
|
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||||
|
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||||
|
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(config.to_json_string())
|
||||||
|
|
||||||
|
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||||
|
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||||
|
"""
|
||||||
print("Loading weights...")
|
print("Loading weights...")
|
||||||
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
|
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
|
||||||
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
|
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
|
||||||
@@ -36,35 +56,11 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
|||||||
init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
|
init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
|
||||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||||
# if n_ctx > 0:
|
|
||||||
# init_params[0] = init_params[0][:n_ctx]
|
|
||||||
# if n_special > 0:
|
|
||||||
# init_params[0] = np.concatenate(
|
|
||||||
# [init_params[1],
|
|
||||||
# (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
|
|
||||||
# init_params[0]
|
|
||||||
# ], 0)
|
|
||||||
# else:
|
|
||||||
# init_params[0] = np.concatenate(
|
|
||||||
# [init_params[1],
|
|
||||||
# init_params[0]
|
|
||||||
# ], 0)
|
|
||||||
# del init_params[1]
|
|
||||||
# if n_transfer == -1:
|
|
||||||
# n_transfer = 0
|
|
||||||
# else:
|
|
||||||
# n_transfer = 1 + n_transfer * 12
|
|
||||||
|
|
||||||
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
||||||
del init_params[1]
|
del init_params[1]
|
||||||
init_params = [arr.squeeze() for arr in init_params]
|
init_params = [arr.squeeze() for arr in init_params]
|
||||||
|
|
||||||
# Construct model
|
|
||||||
if openai_config_file == "":
|
|
||||||
config = OpenAIGPTConfig()
|
|
||||||
else:
|
|
||||||
config = OpenAIGPTConfig(openai_config_file)
|
|
||||||
model = OpenAIGPTModel(config)
|
|
||||||
try:
|
try:
|
||||||
assert model.embed.weight.shape == init_params[0].shape
|
assert model.embed.weight.shape == init_params[0].shape
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
@@ -109,15 +105,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
|||||||
raise
|
raise
|
||||||
print("Initialize PyTorch weight {}".format(name))
|
print("Initialize PyTorch weight {}".format(name))
|
||||||
pointer.data = torch.from_numpy(array)
|
pointer.data = torch.from_numpy(array)
|
||||||
|
return model
|
||||||
# Save pytorch-model
|
|
||||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
|
||||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
|
||||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
|
||||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
|
||||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
|
||||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(config.to_json_string())
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
@@ -28,9 +28,23 @@ import numpy as np
|
|||||||
from .modeling import BertConfig, BertForPreTraining
|
from .modeling import BertConfig, BertForPreTraining
|
||||||
|
|
||||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||||
config_path = os.path.abspath(bert_config_file)
|
# Initialise PyTorch model
|
||||||
|
config = BertConfig.from_json_file(bert_config_file)
|
||||||
|
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||||
|
model = BertForPreTraining(config)
|
||||||
|
|
||||||
|
# Load weights from tf checkpoint
|
||||||
|
load_tf_weights_in_bert(model, tf_checkpoint_path)
|
||||||
|
|
||||||
|
# Save pytorch-model
|
||||||
|
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||||
|
torch.save(model.state_dict(), pytorch_dump_path)
|
||||||
|
|
||||||
|
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||||
|
""" Load tf checkpoints in a pytorch model
|
||||||
|
"""
|
||||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||||
print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
|
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||||
# Load weights from TF model
|
# Load weights from TF model
|
||||||
init_vars = tf.train.list_variables(tf_path)
|
init_vars = tf.train.list_variables(tf_path)
|
||||||
names = []
|
names = []
|
||||||
@@ -41,11 +55,6 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
|||||||
names.append(name)
|
names.append(name)
|
||||||
arrays.append(array)
|
arrays.append(array)
|
||||||
|
|
||||||
# Initialise PyTorch model
|
|
||||||
config = BertConfig.from_json_file(bert_config_file)
|
|
||||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
|
||||||
model = BertForPreTraining(config)
|
|
||||||
|
|
||||||
for name, array in zip(names, arrays):
|
for name, array in zip(names, arrays):
|
||||||
name = name.split('/')
|
name = name.split('/')
|
||||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||||
@@ -81,11 +90,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
|||||||
raise
|
raise
|
||||||
print("Initialize PyTorch weight {}".format(name))
|
print("Initialize PyTorch weight {}".format(name))
|
||||||
pointer.data = torch.from_numpy(array)
|
pointer.data = torch.from_numpy(array)
|
||||||
|
return model
|
||||||
# Save pytorch-model
|
|
||||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
|
||||||
torch.save(model.state_dict(), pytorch_dump_path)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
@@ -106,7 +106,6 @@ def build_tf_to_pytorch_map(model, config):
|
|||||||
'transformer/r_w_bias': r_w_list})
|
'transformer/r_w_bias': r_w_list})
|
||||||
return tf_to_pt_map
|
return tf_to_pt_map
|
||||||
|
|
||||||
|
|
||||||
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||||
transfo_xl_config_file,
|
transfo_xl_config_file,
|
||||||
pytorch_dump_folder_path,
|
pytorch_dump_folder_path,
|
||||||
@@ -140,6 +139,20 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
|||||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||||
model = TransfoXLModel(config)
|
model = TransfoXLModel(config)
|
||||||
|
|
||||||
|
model = load_tf_weights_in_transfo_xl(model, config, tf_path)
|
||||||
|
# Save pytorch-model
|
||||||
|
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||||
|
pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
|
||||||
|
print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
|
||||||
|
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||||
|
print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
|
||||||
|
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(config.to_json_string())
|
||||||
|
|
||||||
|
|
||||||
|
def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||||
|
""" Load tf checkpoints in a pytorch model
|
||||||
|
"""
|
||||||
# Build TF to PyTorch weights loading map
|
# Build TF to PyTorch weights loading map
|
||||||
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
|
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
|
||||||
|
|
||||||
@@ -183,16 +196,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
|||||||
tf_weights.pop(name + '/Adam_1', None)
|
tf_weights.pop(name + '/Adam_1', None)
|
||||||
|
|
||||||
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||||
|
return model
|
||||||
# Save pytorch-model
|
|
||||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
|
||||||
pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
|
|
||||||
print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
|
|
||||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
|
||||||
print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
|
|
||||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(config.to_json_string())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from torch import nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
|
from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -47,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
}
|
}
|
||||||
CONFIG_NAME = 'bert_config.json'
|
CONFIG_NAME = 'bert_config.json'
|
||||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||||
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
"""Implementation of the gelu activation function.
|
"""Implementation of the gelu activation function.
|
||||||
@@ -445,7 +447,8 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
|
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
|
||||||
|
from_tf=False, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
@@ -463,6 +466,10 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
- a path or url to a pretrained model archive containing:
|
- a path or url to a pretrained model archive containing:
|
||||||
. `bert_config.json` a configuration file for the model
|
. `bert_config.json` a configuration file for the model
|
||||||
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
||||||
|
- a path or url to a pretrained model archive containing:
|
||||||
|
. `bert_config.json` a configuration file for the model
|
||||||
|
. `model.chkpt` a TensorFlow checkpoint
|
||||||
|
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
||||||
*inputs, **kwargs: additional input for the specific Bert class
|
*inputs, **kwargs: additional input for the specific Bert class
|
||||||
@@ -490,7 +497,7 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
logger.info("loading archive file {} from cache at {}".format(
|
logger.info("loading archive file {} from cache at {}".format(
|
||||||
archive_file, resolved_archive_file))
|
archive_file, resolved_archive_file))
|
||||||
tempdir = None
|
tempdir = None
|
||||||
if os.path.isdir(resolved_archive_file):
|
if os.path.isdir(resolved_archive_file) or from_tf:
|
||||||
serialization_dir = resolved_archive_file
|
serialization_dir = resolved_archive_file
|
||||||
else:
|
else:
|
||||||
# Extract archive to temp dir
|
# Extract archive to temp dir
|
||||||
@@ -506,10 +513,17 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
logger.info("Model config {}".format(config))
|
logger.info("Model config {}".format(config))
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
model = cls(config, *inputs, **kwargs)
|
model = cls(config, *inputs, **kwargs)
|
||||||
if state_dict is None:
|
if state_dict is None and not from_tf:
|
||||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||||
state_dict = torch.load(weights_path)
|
state_dict = torch.load(weights_path)
|
||||||
|
if tempdir:
|
||||||
|
# Clean up temp dir
|
||||||
|
shutil.rmtree(tempdir)
|
||||||
|
if from_tf:
|
||||||
|
# Directly load from a TensorFlow checkpoint
|
||||||
|
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
||||||
|
return load_tf_weights_in_bert(model, weights_path)
|
||||||
|
# Load from a PyTorch state_dict
|
||||||
old_keys = []
|
old_keys = []
|
||||||
new_keys = []
|
new_keys = []
|
||||||
for key in state_dict.keys():
|
for key in state_dict.keys():
|
||||||
@@ -550,9 +564,6 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
if len(error_msgs) > 0:
|
if len(error_msgs) > 0:
|
||||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||||
if tempdir:
|
|
||||||
# Clean up temp dir
|
|
||||||
shutil.rmtree(tempdir)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -32,14 +32,14 @@ from torch.nn.parameter import Parameter
|
|||||||
|
|
||||||
from .modeling import BertLayerNorm as LayerNorm
|
from .modeling import BertLayerNorm as LayerNorm
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
|
from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"}
|
||||||
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz",
|
CONFIG_NAME = "openai_gpt_config.json"
|
||||||
}
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
CONFIG_NAME = 'openai_gpt_config.json'
|
|
||||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
|
||||||
|
|
||||||
def gelu(x):
|
def gelu(x):
|
||||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||||
@@ -49,16 +49,15 @@ def swish(x):
|
|||||||
return x * torch.sigmoid(x)
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
ACT_FNS = {
|
ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
|
||||||
'relu': nn.ReLU,
|
|
||||||
'swish': swish,
|
|
||||||
'gelu': gelu
|
|
||||||
}
|
|
||||||
|
|
||||||
class OpenAIGPTConfig(object):
|
class OpenAIGPTConfig(object):
|
||||||
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
|
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
vocab_size_or_config_json_file=40478,
|
vocab_size_or_config_json_file=40478,
|
||||||
n_special=0,
|
n_special=0,
|
||||||
n_ctx=512,
|
n_ctx=512,
|
||||||
@@ -69,7 +68,8 @@ class OpenAIGPTConfig(object):
|
|||||||
resid_pdrop=0.1,
|
resid_pdrop=0.1,
|
||||||
embd_pdrop=0.1,
|
embd_pdrop=0.1,
|
||||||
attn_pdrop=0.1,
|
attn_pdrop=0.1,
|
||||||
initializer_range=0.02):
|
initializer_range=0.02,
|
||||||
|
):
|
||||||
"""Constructs OpenAIGPTConfig.
|
"""Constructs OpenAIGPTConfig.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -91,7 +91,7 @@ class OpenAIGPTConfig(object):
|
|||||||
initializing all weight matrices.
|
initializing all weight matrices.
|
||||||
"""
|
"""
|
||||||
if isinstance(vocab_size_or_config_json_file, str):
|
if isinstance(vocab_size_or_config_json_file, str):
|
||||||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
||||||
json_config = json.loads(reader.read())
|
json_config = json.loads(reader.read())
|
||||||
for key, value in json_config.items():
|
for key, value in json_config.items():
|
||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
@@ -108,8 +108,10 @@ class OpenAIGPTConfig(object):
|
|||||||
self.attn_pdrop = attn_pdrop
|
self.attn_pdrop = attn_pdrop
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
else:
|
else:
|
||||||
raise ValueError("First argument must be either a vocabulary size (int)"
|
raise ValueError(
|
||||||
"or the path to a pretrained model config file (str)")
|
"First argument must be either a vocabulary size (int)"
|
||||||
|
"or the path to a pretrained model config file (str)"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_num_embeddings(self):
|
def total_num_embeddings(self):
|
||||||
@@ -126,7 +128,7 @@ class OpenAIGPTConfig(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_json_file(cls, json_file):
|
def from_json_file(cls, json_file):
|
||||||
"""Constructs a `OpenAIGPTConfig` from a json file of parameters."""
|
"""Constructs a `OpenAIGPTConfig` from a json file of parameters."""
|
||||||
with open(json_file, "r", encoding='utf-8') as reader:
|
with open(json_file, "r", encoding="utf-8") as reader:
|
||||||
text = reader.read()
|
text = reader.read()
|
||||||
return cls.from_dict(json.loads(text))
|
return cls.from_dict(json.loads(text))
|
||||||
|
|
||||||
@@ -142,6 +144,7 @@ class OpenAIGPTConfig(object):
|
|||||||
"""Serializes this instance to a JSON string."""
|
"""Serializes this instance to a JSON string."""
|
||||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
|
||||||
class Conv1D(nn.Module):
|
class Conv1D(nn.Module):
|
||||||
def __init__(self, nf, rf, nx):
|
def __init__(self, nf, rf, nx):
|
||||||
super(Conv1D, self).__init__()
|
super(Conv1D, self).__init__()
|
||||||
@@ -171,7 +174,7 @@ class Attention(nn.Module):
|
|||||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||||
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||||
assert n_state % config.n_head == 0
|
assert n_state % config.n_head == 0
|
||||||
self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
self.register_buffer("b", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||||
self.n_head = config.n_head
|
self.n_head = config.n_head
|
||||||
self.split_size = n_state
|
self.split_size = n_state
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@@ -284,11 +287,12 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
|
|||||||
nn.init.normal_(self.linear.weight, std=0.02)
|
nn.init.normal_(self.linear.weight, std=0.02)
|
||||||
nn.init.normal_(self.linear.bias, 0)
|
nn.init.normal_(self.linear.bias, 0)
|
||||||
|
|
||||||
def forward(self, hidden_states, multiple_choice_token_mask):
|
def forward(self, hidden_states, mc_token_mask):
|
||||||
# Classification logits
|
# Classification logits
|
||||||
# hidden_states = hidden_states.view(-1, self.n_embd)
|
# hidden_states = hidden_states.view(-1, self.n_embd)
|
||||||
# multiple_choice_token_mask = multiple_choice_token_mask.view(-1, 1).expand_as(hidden_states)
|
# mc_token_mask = mc_token_mask.view(-1, 1).expand_as(hidden_states)
|
||||||
multiple_choice_h = hidden_states * multiple_choice_token_mask.unsqueeze(-1)
|
mc_token_mask = mc_token_mask.float()
|
||||||
|
multiple_choice_h = hidden_states * mc_token_mask.unsqueeze(-1)
|
||||||
multiple_choice_h = multiple_choice_h.sum(dim=-2)
|
multiple_choice_h = multiple_choice_h.sum(dim=-2)
|
||||||
# flat = x[..., 0].contiguous().view(-1)
|
# flat = x[..., 0].contiguous().view(-1)
|
||||||
# multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
|
# multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
|
||||||
@@ -307,6 +311,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
""" An abstract class to handle weights initialization and
|
""" An abstract class to handle weights initialization and
|
||||||
a simple interface for dowloading and loading pretrained models.
|
a simple interface for dowloading and loading pretrained models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super(OpenAIGPTPreTrainedModel, self).__init__()
|
super(OpenAIGPTPreTrainedModel, self).__init__()
|
||||||
if not isinstance(config, OpenAIGPTConfig):
|
if not isinstance(config, OpenAIGPTConfig):
|
||||||
@@ -315,7 +320,8 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
"To create a model from a pretrained model use "
|
"To create a model from a pretrained model use "
|
||||||
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
||||||
self.__class__.__name__, self.__class__.__name__
|
self.__class__.__name__, self.__class__.__name__
|
||||||
))
|
)
|
||||||
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def init_weights(self, module):
|
def init_weights(self, module):
|
||||||
@@ -335,8 +341,9 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
|
def from_pretrained(
|
||||||
*inputs, **kwargs):
|
cls, pretrained_model_name, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
@@ -348,6 +355,10 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
- a path or url to a pretrained model archive containing:
|
- a path or url to a pretrained model archive containing:
|
||||||
. `openai_gpt_config.json` a configuration file for the model
|
. `openai_gpt_config.json` a configuration file for the model
|
||||||
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
|
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
|
||||||
|
- a path or url to a pretrained model archive containing:
|
||||||
|
. `bert_config.json` a configuration file for the model
|
||||||
|
. a series of NumPy files containing OpenAI TensorFlow trained weights
|
||||||
|
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
||||||
*inputs, **kwargs: additional input for the specific Bert class
|
*inputs, **kwargs: additional input for the specific Bert class
|
||||||
@@ -365,24 +376,22 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find any file "
|
"We assumed '{}' was a path or url but couldn't find any file "
|
||||||
"associated to this path or url.".format(
|
"associated to this path or url.".format(
|
||||||
pretrained_model_name,
|
pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file
|
||||||
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
)
|
||||||
archive_file))
|
)
|
||||||
return None
|
return None
|
||||||
if resolved_archive_file == archive_file:
|
if resolved_archive_file == archive_file:
|
||||||
logger.info("loading archive file {}".format(archive_file))
|
logger.info("loading archive file {}".format(archive_file))
|
||||||
else:
|
else:
|
||||||
logger.info("loading archive file {} from cache at {}".format(
|
logger.info("loading archive file {} from cache at {}".format(archive_file, resolved_archive_file))
|
||||||
archive_file, resolved_archive_file))
|
|
||||||
tempdir = None
|
tempdir = None
|
||||||
if os.path.isdir(resolved_archive_file):
|
if os.path.isdir(resolved_archive_file):
|
||||||
serialization_dir = resolved_archive_file
|
serialization_dir = resolved_archive_file
|
||||||
else:
|
else:
|
||||||
# Extract archive to temp dir
|
# Extract archive to temp dir
|
||||||
tempdir = tempfile.mkdtemp()
|
tempdir = tempfile.mkdtemp()
|
||||||
logger.info("extracting archive file {} to temp dir {}".format(
|
logger.info("extracting archive file {} to temp dir {}".format(resolved_archive_file, tempdir))
|
||||||
resolved_archive_file, tempdir))
|
with tarfile.open(resolved_archive_file, "r:gz") as archive:
|
||||||
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
|
||||||
archive.extractall(tempdir)
|
archive.extractall(tempdir)
|
||||||
serialization_dir = tempdir
|
serialization_dir = tempdir
|
||||||
# Load config
|
# Load config
|
||||||
@@ -391,18 +400,24 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
logger.info("Model config {}".format(config))
|
logger.info("Model config {}".format(config))
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
model = cls(config, *inputs, **kwargs)
|
model = cls(config, *inputs, **kwargs)
|
||||||
if state_dict is None:
|
if state_dict is None and not from_tf:
|
||||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||||
state_dict = torch.load(weights_path)
|
state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||||
|
if tempdir:
|
||||||
|
# Clean up temp dir
|
||||||
|
shutil.rmtree(tempdir)
|
||||||
|
if from_tf:
|
||||||
|
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
|
||||||
|
return load_tf_weights_in_openai_gpt(model, serialization_dir)
|
||||||
|
|
||||||
old_keys = []
|
old_keys = []
|
||||||
new_keys = []
|
new_keys = []
|
||||||
for key in state_dict.keys():
|
for key in state_dict.keys():
|
||||||
new_key = None
|
new_key = None
|
||||||
if 'gamma' in key:
|
if "gamma" in key:
|
||||||
new_key = key.replace('gamma', 'weight')
|
new_key = key.replace("gamma", "weight")
|
||||||
if 'beta' in key:
|
if "beta" in key:
|
||||||
new_key = key.replace('beta', 'bias')
|
new_key = key.replace("beta", "bias")
|
||||||
if new_key:
|
if new_key:
|
||||||
old_keys.append(key)
|
old_keys.append(key)
|
||||||
new_keys.append(new_key)
|
new_keys.append(new_key)
|
||||||
@@ -413,34 +428,36 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
unexpected_keys = []
|
unexpected_keys = []
|
||||||
error_msgs = []
|
error_msgs = []
|
||||||
# copy state_dict so _load_from_state_dict can modify it
|
# copy state_dict so _load_from_state_dict can modify it
|
||||||
metadata = getattr(state_dict, '_metadata', None)
|
metadata = getattr(state_dict, "_metadata", None)
|
||||||
state_dict = state_dict.copy()
|
state_dict = state_dict.copy()
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
state_dict._metadata = metadata
|
state_dict._metadata = metadata
|
||||||
|
|
||||||
def load(module, prefix=''):
|
def load(module, prefix=""):
|
||||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||||
module._load_from_state_dict(
|
module._load_from_state_dict(
|
||||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
||||||
|
)
|
||||||
for name, child in module._modules.items():
|
for name, child in module._modules.items():
|
||||||
if child is not None:
|
if child is not None:
|
||||||
load(child, prefix + name + '.')
|
load(child, prefix + name + ".")
|
||||||
load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
|
||||||
|
load(model.transformer if hasattr(model, "transformer") else model, prefix="")
|
||||||
if len(missing_keys) > 0:
|
if len(missing_keys) > 0:
|
||||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
logger.info(
|
||||||
model.__class__.__name__, missing_keys))
|
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
|
||||||
|
)
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
logger.info("Weights from pretrained model not used in {}: {}".format(
|
logger.info(
|
||||||
model.__class__.__name__, unexpected_keys))
|
"Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
|
||||||
|
)
|
||||||
if len(error_msgs) > 0:
|
if len(error_msgs) > 0:
|
||||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
raise RuntimeError(
|
||||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
||||||
|
)
|
||||||
# Add additional embeddings for special tokens if needed
|
# Add additional embeddings for special tokens if needed
|
||||||
if num_special_tokens != config.n_special:
|
if num_special_tokens is not None and num_special_tokens != config.n_special:
|
||||||
model.set_num_special_tokens(num_special_tokens)
|
model.set_num_special_tokens(num_special_tokens)
|
||||||
if tempdir:
|
|
||||||
# Clean up temp dir
|
|
||||||
shutil.rmtree(tempdir)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@@ -495,6 +512,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
hidden_states = model(input_ids)
|
hidden_states = model(input_ids)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(OpenAIGPTModel, self).__init__(config)
|
super(OpenAIGPTModel, self).__init__(config)
|
||||||
total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx
|
total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx
|
||||||
@@ -544,6 +562,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
hidden_states = block(hidden_states)
|
hidden_states = block(hidden_states)
|
||||||
return hidden_states.view(*input_shape, hidden_states.size(-1))
|
return hidden_states.view(*input_shape, hidden_states.size(-1))
|
||||||
|
|
||||||
|
|
||||||
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||||
"""OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
|
"""OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
|
||||||
|
|
||||||
@@ -602,6 +621,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
lm_logits = model(input_ids)
|
lm_logits = model(input_ids)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(OpenAIGPTLMHeadModel, self).__init__(config)
|
super(OpenAIGPTLMHeadModel, self).__init__(config)
|
||||||
self.transformer = OpenAIGPTModel(config)
|
self.transformer = OpenAIGPTModel(config)
|
||||||
@@ -622,6 +642,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
return loss
|
return loss
|
||||||
return lm_logits
|
return lm_logits
|
||||||
|
|
||||||
|
|
||||||
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||||
"""OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
|
"""OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
|
||||||
|
|
||||||
@@ -653,7 +674,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
Inputs:
|
Inputs:
|
||||||
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
||||||
with the word BPE token indices selected in the range [0, config.vocab_size[
|
with the word BPE token indices selected in the range [0, config.vocab_size[
|
||||||
`multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
`mc_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
||||||
with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise.
|
with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise.
|
||||||
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||||
with the position indices (selected in the range [config.vocab_size + config.n_special,
|
with the position indices (selected in the range [config.vocab_size + config.n_special,
|
||||||
@@ -678,14 +699,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
```python
|
```python
|
||||||
# Already been converted into BPE token ids
|
# Already been converted into BPE token ids
|
||||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||||
multiple_choice_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
mc_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||||
|
|
||||||
config = modeling_openai.OpenAIGPTConfig()
|
config = modeling_openai.OpenAIGPTConfig()
|
||||||
|
|
||||||
model = modeling_openai.OpenAIGPTLMHeadModel(config)
|
model = modeling_openai.OpenAIGPTLMHeadModel(config)
|
||||||
lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask)
|
lm_logits, multiple_choice_logits = model(input_ids, mc_token_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
|
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
|
||||||
self.transformer = OpenAIGPTModel(config)
|
self.transformer = OpenAIGPTModel(config)
|
||||||
@@ -698,18 +720,17 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||||
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
|
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
|
||||||
|
|
||||||
def forward(self, input_ids, multiple_choice_token_mask, position_ids=None, token_type_ids=None,
|
def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
|
||||||
lm_labels=None, multiple_choice_labels=None):
|
|
||||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||||
lm_logits = self.lm_head(hidden_states)
|
lm_logits = self.lm_head(hidden_states)
|
||||||
multiple_choice_logits = self.multiple_choice_head(hidden_states, multiple_choice_token_mask)
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_mask)
|
||||||
losses = []
|
losses = []
|
||||||
if lm_labels is not None:
|
if lm_labels is not None:
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
|
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
|
||||||
if multiple_choice_labels is not None:
|
if mc_labels is not None:
|
||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
losses.append(loss_fct(multiple_choice_logits, multiple_choice_labels.view(-1)))
|
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
||||||
if losses:
|
if losses:
|
||||||
return losses
|
return losses
|
||||||
return lm_logits, multiple_choice_logits
|
return lm_logits, mc_logits
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from torch.nn.parameter import Parameter
|
|||||||
from .modeling import BertLayerNorm as LayerNorm
|
from .modeling import BertLayerNorm as LayerNorm
|
||||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
|
from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -48,6 +49,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|||||||
}
|
}
|
||||||
CONFIG_NAME = 'transfo_xl_config.json'
|
CONFIG_NAME = 'transfo_xl_config.json'
|
||||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||||
|
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||||
|
|
||||||
class TransfoXLConfig(object):
|
class TransfoXLConfig(object):
|
||||||
"""Configuration class to store the configuration of a `TransfoXLModel`.
|
"""Configuration class to store the configuration of a `TransfoXLModel`.
|
||||||
@@ -749,7 +751,7 @@ class TransfoXLPreTrainedModel(nn.Module):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
|
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
|
||||||
*inputs, **kwargs):
|
from_tf=False, *inputs, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
@@ -761,6 +763,10 @@ class TransfoXLPreTrainedModel(nn.Module):
|
|||||||
- a path or url to a pretrained model archive containing:
|
- a path or url to a pretrained model archive containing:
|
||||||
. `transfo_xl_config.json` a configuration file for the model
|
. `transfo_xl_config.json` a configuration file for the model
|
||||||
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
|
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
|
||||||
|
- a path or url to a pretrained model archive containing:
|
||||||
|
. `bert_config.json` a configuration file for the model
|
||||||
|
. `model.chkpt` a TensorFlow checkpoint
|
||||||
|
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
||||||
*inputs, **kwargs: additional input for the specific Bert class
|
*inputs, **kwargs: additional input for the specific Bert class
|
||||||
@@ -799,9 +805,12 @@ class TransfoXLPreTrainedModel(nn.Module):
|
|||||||
logger.info("Model config {}".format(config))
|
logger.info("Model config {}".format(config))
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
model = cls(config, *inputs, **kwargs)
|
model = cls(config, *inputs, **kwargs)
|
||||||
if state_dict is None:
|
if state_dict is None and not from_tf:
|
||||||
state_dict = torch.load(resolved_archive_file)
|
state_dict = torch.load(resolved_archive_file)
|
||||||
|
if from_tf:
|
||||||
|
# Directly load from a TensorFlow checkpoint
|
||||||
|
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
||||||
|
return load_tf_weights_in_transfo_xl(model, weights_path)
|
||||||
missing_keys = []
|
missing_keys = []
|
||||||
unexpected_keys = []
|
unexpected_keys = []
|
||||||
error_msgs = []
|
error_msgs = []
|
||||||
|
|||||||
@@ -130,6 +130,9 @@ class OpenAIGPTTokenizer(object):
|
|||||||
else:
|
else:
|
||||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.encoder) + len(self.special_tokens)
|
||||||
|
|
||||||
def set_special_tokens(self, special_tokens):
|
def set_special_tokens(self, special_tokens):
|
||||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user