directly load from TF checkpoints + code cleanup
This commit is contained in:
@@ -2,6 +2,7 @@ __version__ = "0.5.0"
|
||||
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
|
||||
from .tokenization_openai import OpenAIGPTTokenizer
|
||||
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
|
||||
|
||||
from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
||||
BertForMaskedLM, BertForNextSentencePrediction,
|
||||
BertForSequenceClassification, BertForMultipleChoice,
|
||||
@@ -9,6 +10,11 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
|
||||
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
|
||||
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
|
||||
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel)
|
||||
|
||||
from .optimization import BertAdam
|
||||
from .optimization_openai import OpenAIAdam
|
||||
|
||||
from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt
|
||||
from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert
|
||||
from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl
|
||||
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
|
||||
@@ -26,9 +26,29 @@ import numpy as np
|
||||
|
||||
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
|
||||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
||||
# Load weights from TF model
|
||||
# Construct model
|
||||
if openai_config_file == "":
|
||||
config = OpenAIGPTConfig()
|
||||
else:
|
||||
config = OpenAIGPTConfig(openai_config_file)
|
||||
model = OpenAIGPTModel(config)
|
||||
|
||||
# Load weights from numpy
|
||||
load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)
|
||||
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path):
|
||||
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
|
||||
"""
|
||||
print("Loading weights...")
|
||||
names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
|
||||
shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
|
||||
@@ -36,35 +56,11 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
||||
init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
|
||||
init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
|
||||
init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
|
||||
# if n_ctx > 0:
|
||||
# init_params[0] = init_params[0][:n_ctx]
|
||||
# if n_special > 0:
|
||||
# init_params[0] = np.concatenate(
|
||||
# [init_params[1],
|
||||
# (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
|
||||
# init_params[0]
|
||||
# ], 0)
|
||||
# else:
|
||||
# init_params[0] = np.concatenate(
|
||||
# [init_params[1],
|
||||
# init_params[0]
|
||||
# ], 0)
|
||||
# del init_params[1]
|
||||
# if n_transfer == -1:
|
||||
# n_transfer = 0
|
||||
# else:
|
||||
# n_transfer = 1 + n_transfer * 12
|
||||
|
||||
init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
|
||||
del init_params[1]
|
||||
init_params = [arr.squeeze() for arr in init_params]
|
||||
|
||||
# Construct model
|
||||
if openai_config_file == "":
|
||||
config = OpenAIGPTConfig()
|
||||
else:
|
||||
config = OpenAIGPTConfig(openai_config_file)
|
||||
model = OpenAIGPTModel(config)
|
||||
try:
|
||||
assert model.embed.weight.shape == init_params[0].shape
|
||||
except AssertionError as e:
|
||||
@@ -109,15 +105,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
|
||||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
|
||||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(pytorch_config_dump_path))
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
return model
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -28,9 +28,23 @@ import numpy as np
|
||||
from .modeling import BertConfig, BertForPreTraining
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
|
||||
config_path = os.path.abspath(bert_config_file)
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(bert_config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = BertForPreTraining(config)
|
||||
|
||||
# Load weights from tf checkpoint
|
||||
load_tf_weights_in_bert(model, tf_checkpoint_path)
|
||||
|
||||
# Save pytorch-model
|
||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
|
||||
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
@@ -41,11 +55,6 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
|
||||
# Initialise PyTorch model
|
||||
config = BertConfig.from_json_file(bert_config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = BertForPreTraining(config)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name.split('/')
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
@@ -81,11 +90,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
|
||||
raise
|
||||
print("Initialize PyTorch weight {}".format(name))
|
||||
pointer.data = torch.from_numpy(array)
|
||||
|
||||
# Save pytorch-model
|
||||
print("Save PyTorch model to {}".format(pytorch_dump_path))
|
||||
torch.save(model.state_dict(), pytorch_dump_path)
|
||||
|
||||
return model
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -106,7 +106,6 @@ def build_tf_to_pytorch_map(model, config):
|
||||
'transformer/r_w_bias': r_w_list})
|
||||
return tf_to_pt_map
|
||||
|
||||
|
||||
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
transfo_xl_config_file,
|
||||
pytorch_dump_folder_path,
|
||||
@@ -140,6 +139,20 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = TransfoXLModel(config)
|
||||
|
||||
model = load_tf_weights_in_transfo_xl(model, config, tf_path)
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||
pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
|
||||
print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
|
||||
def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
"""
|
||||
# Build TF to PyTorch weights loading map
|
||||
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
|
||||
|
||||
@@ -183,16 +196,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
tf_weights.pop(name + '/Adam_1', None)
|
||||
|
||||
print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
|
||||
|
||||
# Save pytorch-model
|
||||
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
|
||||
pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
|
||||
print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
|
||||
torch.save(model.state_dict(), pytorch_weights_dump_path)
|
||||
print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
|
||||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
|
||||
f.write(config.to_json_string())
|
||||
|
||||
return model
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -33,6 +33,7 @@ from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from .file_utils import cached_path
|
||||
from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -47,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
}
|
||||
CONFIG_NAME = 'bert_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
def gelu(x):
|
||||
"""Implementation of the gelu activation function.
|
||||
@@ -445,7 +447,8 @@ class BertPreTrainedModel(nn.Module):
|
||||
module.bias.data.zero_()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
|
||||
from_tf=False, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
@@ -463,6 +466,10 @@ class BertPreTrainedModel(nn.Module):
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. `model.chkpt` a TensorFlow checkpoint
|
||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
@@ -490,7 +497,7 @@ class BertPreTrainedModel(nn.Module):
|
||||
logger.info("loading archive file {} from cache at {}".format(
|
||||
archive_file, resolved_archive_file))
|
||||
tempdir = None
|
||||
if os.path.isdir(resolved_archive_file):
|
||||
if os.path.isdir(resolved_archive_file) or from_tf:
|
||||
serialization_dir = resolved_archive_file
|
||||
else:
|
||||
# Extract archive to temp dir
|
||||
@@ -506,10 +513,17 @@ class BertPreTrainedModel(nn.Module):
|
||||
logger.info("Model config {}".format(config))
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None:
|
||||
if state_dict is None and not from_tf:
|
||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||
state_dict = torch.load(weights_path)
|
||||
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
||||
return load_tf_weights_in_bert(model, weights_path)
|
||||
# Load from a PyTorch state_dict
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
@@ -550,9 +564,6 @@ class BertPreTrainedModel(nn.Module):
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@@ -32,14 +32,14 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
from .file_utils import cached_path
|
||||
from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz",
|
||||
}
|
||||
CONFIG_NAME = 'openai_gpt_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"}
|
||||
CONFIG_NAME = "openai_gpt_config.json"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
||||
|
||||
def gelu(x):
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
@@ -49,16 +49,15 @@ def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
ACT_FNS = {
|
||||
'relu': nn.ReLU,
|
||||
'swish': swish,
|
||||
'gelu': gelu
|
||||
}
|
||||
ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
|
||||
|
||||
|
||||
class OpenAIGPTConfig(object):
|
||||
"""Configuration class to store the configuration of a `OpenAIGPTModel`.
|
||||
"""
|
||||
def __init__(self,
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size_or_config_json_file=40478,
|
||||
n_special=0,
|
||||
n_ctx=512,
|
||||
@@ -69,7 +68,8 @@ class OpenAIGPTConfig(object):
|
||||
resid_pdrop=0.1,
|
||||
embd_pdrop=0.1,
|
||||
attn_pdrop=0.1,
|
||||
initializer_range=0.02):
|
||||
initializer_range=0.02,
|
||||
):
|
||||
"""Constructs OpenAIGPTConfig.
|
||||
|
||||
Args:
|
||||
@@ -91,7 +91,7 @@ class OpenAIGPTConfig(object):
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
if isinstance(vocab_size_or_config_json_file, str):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
||||
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
self.__dict__[key] = value
|
||||
@@ -108,8 +108,10 @@ class OpenAIGPTConfig(object):
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.initializer_range = initializer_range
|
||||
else:
|
||||
raise ValueError("First argument must be either a vocabulary size (int)"
|
||||
"or the path to a pretrained model config file (str)")
|
||||
raise ValueError(
|
||||
"First argument must be either a vocabulary size (int)"
|
||||
"or the path to a pretrained model config file (str)"
|
||||
)
|
||||
|
||||
@property
|
||||
def total_num_embeddings(self):
|
||||
@@ -126,7 +128,7 @@ class OpenAIGPTConfig(object):
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file):
|
||||
"""Constructs a `OpenAIGPTConfig` from a json file of parameters."""
|
||||
with open(json_file, "r", encoding='utf-8') as reader:
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
return cls.from_dict(json.loads(text))
|
||||
|
||||
@@ -142,6 +144,7 @@ class OpenAIGPTConfig(object):
|
||||
"""Serializes this instance to a JSON string."""
|
||||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
def __init__(self, nf, rf, nx):
|
||||
super(Conv1D, self).__init__()
|
||||
@@ -171,7 +174,7 @@ class Attention(nn.Module):
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||
assert n_state % config.n_head == 0
|
||||
self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||
self.register_buffer("b", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||
self.n_head = config.n_head
|
||||
self.split_size = n_state
|
||||
self.scale = scale
|
||||
@@ -284,11 +287,12 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
|
||||
nn.init.normal_(self.linear.weight, std=0.02)
|
||||
nn.init.normal_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, hidden_states, multiple_choice_token_mask):
|
||||
def forward(self, hidden_states, mc_token_mask):
|
||||
# Classification logits
|
||||
# hidden_states = hidden_states.view(-1, self.n_embd)
|
||||
# multiple_choice_token_mask = multiple_choice_token_mask.view(-1, 1).expand_as(hidden_states)
|
||||
multiple_choice_h = hidden_states * multiple_choice_token_mask.unsqueeze(-1)
|
||||
# mc_token_mask = mc_token_mask.view(-1, 1).expand_as(hidden_states)
|
||||
mc_token_mask = mc_token_mask.float()
|
||||
multiple_choice_h = hidden_states * mc_token_mask.unsqueeze(-1)
|
||||
multiple_choice_h = multiple_choice_h.sum(dim=-2)
|
||||
# flat = x[..., 0].contiguous().view(-1)
|
||||
# multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
|
||||
@@ -307,6 +311,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super(OpenAIGPTPreTrainedModel, self).__init__()
|
||||
if not isinstance(config, OpenAIGPTConfig):
|
||||
@@ -315,7 +320,8 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
"To create a model from a pretrained model use "
|
||||
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
||||
self.__class__.__name__, self.__class__.__name__
|
||||
))
|
||||
)
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def init_weights(self, module):
|
||||
@@ -335,8 +341,9 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None,
|
||||
*inputs, **kwargs):
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
|
||||
):
|
||||
"""
|
||||
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
@@ -348,6 +355,10 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `openai_gpt_config.json` a configuration file for the model
|
||||
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. a series of NumPy files containing OpenAI TensorFlow trained weights
|
||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
@@ -365,24 +376,22 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
"associated to this path or url.".format(
|
||||
pretrained_model_name,
|
||||
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
||||
archive_file))
|
||||
pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file
|
||||
)
|
||||
)
|
||||
return None
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info("loading archive file {}".format(archive_file))
|
||||
else:
|
||||
logger.info("loading archive file {} from cache at {}".format(
|
||||
archive_file, resolved_archive_file))
|
||||
logger.info("loading archive file {} from cache at {}".format(archive_file, resolved_archive_file))
|
||||
tempdir = None
|
||||
if os.path.isdir(resolved_archive_file):
|
||||
serialization_dir = resolved_archive_file
|
||||
else:
|
||||
# Extract archive to temp dir
|
||||
tempdir = tempfile.mkdtemp()
|
||||
logger.info("extracting archive file {} to temp dir {}".format(
|
||||
resolved_archive_file, tempdir))
|
||||
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
||||
logger.info("extracting archive file {} to temp dir {}".format(resolved_archive_file, tempdir))
|
||||
with tarfile.open(resolved_archive_file, "r:gz") as archive:
|
||||
archive.extractall(tempdir)
|
||||
serialization_dir = tempdir
|
||||
# Load config
|
||||
@@ -391,18 +400,24 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
logger.info("Model config {}".format(config))
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None:
|
||||
if state_dict is None and not from_tf:
|
||||
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
||||
state_dict = torch.load(weights_path)
|
||||
state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint (stored as NumPy array)
|
||||
return load_tf_weights_in_openai_gpt(model, serialization_dir)
|
||||
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict.keys():
|
||||
new_key = None
|
||||
if 'gamma' in key:
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta', 'bias')
|
||||
if "gamma" in key:
|
||||
new_key = key.replace("gamma", "weight")
|
||||
if "beta" in key:
|
||||
new_key = key.replace("beta", "bias")
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
@@ -413,34 +428,36 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, '_metadata', None)
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = state_dict.copy()
|
||||
if metadata is not None:
|
||||
state_dict._metadata = metadata
|
||||
|
||||
def load(module, prefix=''):
|
||||
def load(module, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + '.')
|
||||
load(model.transformer if hasattr(model, 'transformer') else model, prefix='')
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
load(model.transformer if hasattr(model, "transformer") else model, prefix="")
|
||||
if len(missing_keys) > 0:
|
||||
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
||||
model.__class__.__name__, missing_keys))
|
||||
logger.info(
|
||||
"Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
|
||||
)
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.info("Weights from pretrained model not used in {}: {}".format(
|
||||
model.__class__.__name__, unexpected_keys))
|
||||
logger.info(
|
||||
"Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
|
||||
)
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
|
||||
)
|
||||
# Add additional embeddings for special tokens if needed
|
||||
if num_special_tokens != config.n_special:
|
||||
if num_special_tokens is not None and num_special_tokens != config.n_special:
|
||||
model.set_num_special_tokens(num_special_tokens)
|
||||
if tempdir:
|
||||
# Clean up temp dir
|
||||
shutil.rmtree(tempdir)
|
||||
return model
|
||||
|
||||
|
||||
@@ -495,6 +512,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
hidden_states = model(input_ids)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTModel, self).__init__(config)
|
||||
total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx
|
||||
@@ -544,6 +562,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
hidden_states = block(hidden_states)
|
||||
return hidden_states.view(*input_shape, hidden_states.size(-1))
|
||||
|
||||
|
||||
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
"""OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
|
||||
|
||||
@@ -602,6 +621,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
lm_logits = model(input_ids)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTLMHeadModel, self).__init__(config)
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
@@ -622,6 +642,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
return loss
|
||||
return lm_logits
|
||||
|
||||
|
||||
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
"""OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
|
||||
|
||||
@@ -653,7 +674,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
Inputs:
|
||||
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
||||
with the word BPE token indices selected in the range [0, config.vocab_size[
|
||||
`multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
||||
`mc_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
|
||||
with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise.
|
||||
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
|
||||
with the position indices (selected in the range [config.vocab_size + config.n_special,
|
||||
@@ -678,14 +699,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
```python
|
||||
# Already been converted into BPE token ids
|
||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
||||
multiple_choice_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
mc_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
||||
|
||||
config = modeling_openai.OpenAIGPTConfig()
|
||||
|
||||
model = modeling_openai.OpenAIGPTLMHeadModel(config)
|
||||
lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask)
|
||||
lm_logits, multiple_choice_logits = model(input_ids, mc_token_mask)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
@@ -698,18 +720,17 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
self.transformer.set_num_special_tokens(num_special_tokens)
|
||||
self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
|
||||
|
||||
def forward(self, input_ids, multiple_choice_token_mask, position_ids=None, token_type_ids=None,
|
||||
lm_labels=None, multiple_choice_labels=None):
|
||||
def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
|
||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
multiple_choice_logits = self.multiple_choice_head(hidden_states, multiple_choice_token_mask)
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_mask)
|
||||
losses = []
|
||||
if lm_labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)))
|
||||
if multiple_choice_labels is not None:
|
||||
if mc_labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
losses.append(loss_fct(multiple_choice_logits, multiple_choice_labels.view(-1)))
|
||||
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
|
||||
if losses:
|
||||
return losses
|
||||
return lm_logits, multiple_choice_logits
|
||||
return lm_logits, mc_logits
|
||||
|
||||
@@ -37,6 +37,7 @@ from torch.nn.parameter import Parameter
|
||||
from .modeling import BertLayerNorm as LayerNorm
|
||||
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
|
||||
from .file_utils import cached_path
|
||||
from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -48,6 +49,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
}
|
||||
CONFIG_NAME = 'transfo_xl_config.json'
|
||||
WEIGHTS_NAME = 'pytorch_model.bin'
|
||||
TF_WEIGHTS_NAME = 'model.ckpt'
|
||||
|
||||
class TransfoXLConfig(object):
|
||||
"""Configuration class to store the configuration of a `TransfoXLModel`.
|
||||
@@ -749,7 +751,7 @@ class TransfoXLPreTrainedModel(nn.Module):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
|
||||
*inputs, **kwargs):
|
||||
from_tf=False, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
@@ -761,6 +763,10 @@ class TransfoXLPreTrainedModel(nn.Module):
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `transfo_xl_config.json` a configuration file for the model
|
||||
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. `model.chkpt` a TensorFlow checkpoint
|
||||
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
||||
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
||||
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
@@ -799,9 +805,12 @@ class TransfoXLPreTrainedModel(nn.Module):
|
||||
logger.info("Model config {}".format(config))
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
if state_dict is None:
|
||||
if state_dict is None and not from_tf:
|
||||
state_dict = torch.load(resolved_archive_file)
|
||||
|
||||
if from_tf:
|
||||
# Directly load from a TensorFlow checkpoint
|
||||
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
||||
return load_tf_weights_in_transfo_xl(model, weights_path)
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
|
||||
@@ -130,6 +130,9 @@ class OpenAIGPTTokenizer(object):
|
||||
else:
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.encoder) + len(self.special_tokens)
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user