From d77dd62ff823d788b7e635e57b6572f204e83264 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 28 Jan 2019 16:50:23 +0100 Subject: [PATCH] directly load from TF checkpoints + code cleanup --- pytorch_pretrained_bert/__init__.py | 6 + .../convert_openai_checkpoint_to_pytorch.py | 58 +++--- .../convert_tf_checkpoint_to_pytorch.py | 29 +-- ...onvert_transfo_xl_checkpoint_to_pytorch.py | 94 +++++----- pytorch_pretrained_bert/modeling.py | 25 ++- pytorch_pretrained_bert/modeling_openai.py | 173 ++++++++++-------- .../modeling_transfo_xl.py | 15 +- .../tokenization_openai.py | 3 + 8 files changed, 225 insertions(+), 178 deletions(-) diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py index 85f2422af6..249607bded 100644 --- a/pytorch_pretrained_bert/__init__.py +++ b/pytorch_pretrained_bert/__init__.py @@ -2,6 +2,7 @@ __version__ = "0.5.0" from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization_openai import OpenAIGPTTokenizer from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) + from .modeling import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, BertForSequenceClassification, BertForMultipleChoice, @@ -9,6 +10,11 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining, from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel) + from .optimization import BertAdam from .optimization_openai import OpenAIAdam + +from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt +from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert +from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE diff --git a/pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py index 0c41741d9a..40740083d0 100755 --- a/pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_openai_checkpoint_to_pytorch.py @@ -26,9 +26,29 @@ import numpy as np from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME - def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): - # Load weights from TF model + # Construct model + if openai_config_file == "": + config = OpenAIGPTConfig() + else: + config = OpenAIGPTConfig(openai_config_file) + model = OpenAIGPTModel(config) + + # Load weights from numpy + load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) + + # Save pytorch-model + pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME + pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME + print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) + torch.save(model.state_dict(), pytorch_weights_dump_path) + print("Save configuration file to {}".format(pytorch_config_dump_path)) + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + +def load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path): + """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here) + """ print("Loading weights...") names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8')) shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8')) @@ -36,35 +56,11 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)] init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1] init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)] - # if n_ctx > 0: - # init_params[0] = init_params[0][:n_ctx] - # if n_special > 0: - # init_params[0] = np.concatenate( - # [init_params[1], - # (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32), - # init_params[0] - # ], 0) - # else: - # init_params[0] = np.concatenate( - # [init_params[1], - # init_params[0] - # ], 0) - # del init_params[1] - # if n_transfer == -1: - # n_transfer = 0 - # else: - # n_transfer = 1 + n_transfer * 12 init_params[0] = np.concatenate([init_params[1], init_params[0]], 0) del init_params[1] init_params = [arr.squeeze() for arr in init_params] - # Construct model - if openai_config_file == "": - config = OpenAIGPTConfig() - else: - config = OpenAIGPTConfig(openai_config_file) - model = OpenAIGPTModel(config) try: assert model.embed.weight.shape == init_params[0].shape except AssertionError as e: @@ -109,15 +105,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c raise print("Initialize PyTorch weight {}".format(name)) pointer.data = torch.from_numpy(array) - - # Save pytorch-model - pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME - pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME - print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) - torch.save(model.state_dict(), pytorch_weights_dump_path) - print("Save configuration file to {}".format(pytorch_config_dump_path)) - with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: - f.write(config.to_json_string()) + return model if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py index 120624bc1b..74622bbb70 100755 --- a/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py @@ -28,9 +28,23 @@ import numpy as np from .modeling import BertConfig, BertForPreTraining def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): - config_path = os.path.abspath(bert_config_file) + # Initialise PyTorch model + config = BertConfig.from_json_file(bert_config_file) + print("Building PyTorch model from configuration: {}".format(str(config))) + model = BertForPreTraining(config) + + # Load weights from tf checkpoint + load_tf_weights_in_bert(model, tf_checkpoint_path) + + # Save pytorch-model + print("Save PyTorch model to {}".format(pytorch_dump_path)) + torch.save(model.state_dict(), pytorch_dump_path) + +def load_tf_weights_in_bert(model, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model + """ tf_path = os.path.abspath(tf_checkpoint_path) - print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path)) + print("Converting TensorFlow checkpoint from {}".format(tf_path)) # Load weights from TF model init_vars = tf.train.list_variables(tf_path) names = [] @@ -41,11 +55,6 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor names.append(name) arrays.append(array) - # Initialise PyTorch model - config = BertConfig.from_json_file(bert_config_file) - print("Building PyTorch model from configuration: {}".format(str(config))) - model = BertForPreTraining(config) - for name, array in zip(names, arrays): name = name.split('/') # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v @@ -81,11 +90,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor raise print("Initialize PyTorch weight {}".format(name)) pointer.data = torch.from_numpy(array) - - # Save pytorch-model - print("Save PyTorch model to {}".format(pytorch_dump_path)) - torch.save(model.state_dict(), pytorch_dump_path) - + return model if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py index eb6b8183ef..4dbb2067d6 100755 --- a/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_transfo_xl_checkpoint_to_pytorch.py @@ -106,7 +106,6 @@ def build_tf_to_pytorch_map(model, config): 'transformer/r_w_bias': r_w_list}) return tf_to_pt_map - def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, @@ -140,50 +139,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, print("Building PyTorch model from configuration: {}".format(str(config))) model = TransfoXLModel(config) - # Build TF to PyTorch weights loading map - tf_to_pt_map = build_tf_to_pytorch_map(model, config) - - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - tf_weights = {} - for name, shape in init_vars: - print("Loading TF weight {} with shape {}".format(name, shape)) - array = tf.train.load_variable(tf_path, name) - tf_weights[name] = array - - for name, pointer in tf_to_pt_map.items(): - assert name in tf_weights - array = tf_weights[name] - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if 'kernel' in name or 'proj' in name: - array = np.transpose(array) - if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1: - # Here we will split the TF weigths - assert len(pointer) == array.shape[0] - for i, p_i in enumerate(pointer): - arr_i = array[i, ...] - try: - assert p_i.shape == arr_i.shape - except AssertionError as e: - e.args += (p_i.shape, arr_i.shape) - raise - print("Initialize PyTorch weight {} for layer {}".format(name, i)) - p_i.data = torch.from_numpy(arr_i) - else: - try: - assert pointer.shape == array.shape - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - print("Initialize PyTorch weight {}".format(name)) - pointer.data = torch.from_numpy(array) - tf_weights.pop(name, None) - tf_weights.pop(name + '/Adam', None) - tf_weights.pop(name + '/Adam_1', None) - - print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys()))) - + model = load_tf_weights_in_transfo_xl(model, config, tf_path) # Save pytorch-model pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) @@ -194,6 +150,54 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, f.write(config.to_json_string()) +def load_tf_weights_in_transfo_xl(model, config, tf_path): + """ Load tf checkpoints in a pytorch model + """ + # Build TF to PyTorch weights loading map + tf_to_pt_map = build_tf_to_pytorch_map(model, config) + + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + tf_weights = {} + for name, shape in init_vars: + print("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + tf_weights[name] = array + + for name, pointer in tf_to_pt_map.items(): + assert name in tf_weights + array = tf_weights[name] + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if 'kernel' in name or 'proj' in name: + array = np.transpose(array) + if ('r_r_bias' in name or 'r_w_bias' in name) and len(pointer) > 1: + # Here we will split the TF weigths + assert len(pointer) == array.shape[0] + for i, p_i in enumerate(pointer): + arr_i = array[i, ...] + try: + assert p_i.shape == arr_i.shape + except AssertionError as e: + e.args += (p_i.shape, arr_i.shape) + raise + print("Initialize PyTorch weight {} for layer {}".format(name, i)) + p_i.data = torch.from_numpy(arr_i) + else: + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + tf_weights.pop(name, None) + tf_weights.pop(name + '/Adam', None) + tf_weights.pop(name + '/Adam_1', None) + + print("Weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys()))) + return model + if __name__ == "__main__": parser = argparse.ArgumentParser() ## Required parameters diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 591082f7ce..1e6966757e 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -33,6 +33,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from .file_utils import cached_path +from .convert_tf_checkpoint_to_pytorch import load_tf_weights_in_bert logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { } CONFIG_NAME = 'bert_config.json' WEIGHTS_NAME = 'pytorch_model.bin' +TF_WEIGHTS_NAME = 'model.ckpt' def gelu(x): """Implementation of the gelu activation function. @@ -445,7 +447,8 @@ class BertPreTrainedModel(nn.Module): module.bias.data.zero_() @classmethod - def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, + from_tf=False, *inputs, **kwargs): """ Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. @@ -463,6 +466,10 @@ class BertPreTrainedModel(nn.Module): - a path or url to a pretrained model archive containing: . `bert_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `model.chkpt` a TensorFlow checkpoint + from_tf: should we load the weights from a locally saved TensorFlow checkpoint cache_dir: an optional path to a folder in which the pre-trained models will be cached. state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models *inputs, **kwargs: additional input for the specific Bert class @@ -490,7 +497,7 @@ class BertPreTrainedModel(nn.Module): logger.info("loading archive file {} from cache at {}".format( archive_file, resolved_archive_file)) tempdir = None - if os.path.isdir(resolved_archive_file): + if os.path.isdir(resolved_archive_file) or from_tf: serialization_dir = resolved_archive_file else: # Extract archive to temp dir @@ -506,10 +513,17 @@ class BertPreTrainedModel(nn.Module): logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) - if state_dict is None: + if state_dict is None and not from_tf: weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) state_dict = torch.load(weights_path) - + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + if from_tf: + # Directly load from a TensorFlow checkpoint + weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) + return load_tf_weights_in_bert(model, weights_path) + # Load from a PyTorch state_dict old_keys = [] new_keys = [] for key in state_dict.keys(): @@ -550,9 +564,6 @@ class BertPreTrainedModel(nn.Module): if len(error_msgs) > 0: raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( model.__class__.__name__, "\n\t".join(error_msgs))) - if tempdir: - # Clean up temp dir - shutil.rmtree(tempdir) return model diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index c3cd165e68..cd72beba66 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -32,14 +32,14 @@ from torch.nn.parameter import Parameter from .modeling import BertLayerNorm as LayerNorm from .file_utils import cached_path +from .convert_openai_checkpoint_to_pytorch import load_tf_weights_in_openai_gpt logger = logging.getLogger(__name__) -PRETRAINED_MODEL_ARCHIVE_MAP = { - 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz", -} -CONFIG_NAME = 'openai_gpt_config.json' -WEIGHTS_NAME = 'pytorch_model.bin' +PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt.tar.gz"} +CONFIG_NAME = "openai_gpt_config.json" +WEIGHTS_NAME = "pytorch_model.bin" + def gelu(x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) @@ -49,27 +49,27 @@ def swish(x): return x * torch.sigmoid(x) -ACT_FNS = { - 'relu': nn.ReLU, - 'swish': swish, - 'gelu': gelu -} +ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu} + class OpenAIGPTConfig(object): """Configuration class to store the configuration of a `OpenAIGPTModel`. """ - def __init__(self, - vocab_size_or_config_json_file=40478, - n_special=0, - n_ctx=512, - n_embd=768, - n_layer=12, - n_head=12, - afn="gelu", - resid_pdrop=0.1, - embd_pdrop=0.1, - attn_pdrop=0.1, - initializer_range=0.02): + + def __init__( + self, + vocab_size_or_config_json_file=40478, + n_special=0, + n_ctx=512, + n_embd=768, + n_layer=12, + n_head=12, + afn="gelu", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + initializer_range=0.02, + ): """Constructs OpenAIGPTConfig. Args: @@ -91,7 +91,7 @@ class OpenAIGPTConfig(object): initializing all weight matrices. """ if isinstance(vocab_size_or_config_json_file, str): - with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: json_config = json.loads(reader.read()) for key, value in json_config.items(): self.__dict__[key] = value @@ -108,8 +108,10 @@ class OpenAIGPTConfig(object): self.attn_pdrop = attn_pdrop self.initializer_range = initializer_range else: - raise ValueError("First argument must be either a vocabulary size (int)" - "or the path to a pretrained model config file (str)") + raise ValueError( + "First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)" + ) @property def total_num_embeddings(self): @@ -126,7 +128,7 @@ class OpenAIGPTConfig(object): @classmethod def from_json_file(cls, json_file): """Constructs a `OpenAIGPTConfig` from a json file of parameters.""" - with open(json_file, "r", encoding='utf-8') as reader: + with open(json_file, "r", encoding="utf-8") as reader: text = reader.read() return cls.from_dict(json.loads(text)) @@ -142,6 +144,7 @@ class OpenAIGPTConfig(object): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + class Conv1D(nn.Module): def __init__(self, nf, rf, nx): super(Conv1D, self).__init__() @@ -171,7 +174,7 @@ class Attention(nn.Module): n_state = nx # in Attention: n_state=768 (nx=n_embd) # [switch nx => n_state from Block to Attention to keep identical to TF implem] assert n_state % config.n_head == 0 - self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) + self.register_buffer("b", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) self.n_head = config.n_head self.split_size = n_state self.scale = scale @@ -186,7 +189,7 @@ class Attention(nn.Module): w = w / math.sqrt(v.size(-1)) # w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights # XD: self.b may be larger than w, so we need to crop it - b = self.b[:, :, :w.size(-2), :w.size(-1)] + b = self.b[:, :, : w.size(-2), : w.size(-1)] w = w * b + -1e9 * (1 - b) w = nn.Softmax(dim=-1)(w) @@ -262,7 +265,7 @@ class OpenAIGPTLMHead(nn.Module): def set_embeddings_weights(self, model_embeddings_weights): embed_shape = model_embeddings_weights.shape self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) - self.decoder.weight = model_embeddings_weights # Tied weights + self.decoder.weight = model_embeddings_weights # Tied weights def forward(self, hidden_state): # Truncated Language modeling logits (we remove the last token) @@ -281,14 +284,15 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): self.dropout = nn.Dropout2d(config.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation self.linear = nn.Linear(config.n_embd, 1) - nn.init.normal_(self.linear.weight, std = 0.02) + nn.init.normal_(self.linear.weight, std=0.02) nn.init.normal_(self.linear.bias, 0) - def forward(self, hidden_states, multiple_choice_token_mask): + def forward(self, hidden_states, mc_token_mask): # Classification logits # hidden_states = hidden_states.view(-1, self.n_embd) - # multiple_choice_token_mask = multiple_choice_token_mask.view(-1, 1).expand_as(hidden_states) - multiple_choice_h = hidden_states * multiple_choice_token_mask.unsqueeze(-1) + # mc_token_mask = mc_token_mask.view(-1, 1).expand_as(hidden_states) + mc_token_mask = mc_token_mask.float() + multiple_choice_h = hidden_states * mc_token_mask.unsqueeze(-1) multiple_choice_h = multiple_choice_h.sum(dim=-2) # flat = x[..., 0].contiguous().view(-1) # multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :] @@ -307,6 +311,7 @@ class OpenAIGPTPreTrainedModel(nn.Module): """ An abstract class to handle weights initialization and a simple interface for dowloading and loading pretrained models. """ + def __init__(self, config, *inputs, **kwargs): super(OpenAIGPTPreTrainedModel, self).__init__() if not isinstance(config, OpenAIGPTConfig): @@ -315,7 +320,8 @@ class OpenAIGPTPreTrainedModel(nn.Module): "To create a model from a pretrained model use " "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ - )) + ) + ) self.config = config def init_weights(self, module): @@ -335,8 +341,9 @@ class OpenAIGPTPreTrainedModel(nn.Module): pass @classmethod - def from_pretrained(cls, pretrained_model_name, num_special_tokens=0, state_dict=None, cache_dir=None, - *inputs, **kwargs): + def from_pretrained( + cls, pretrained_model_name, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs + ): """ Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. @@ -348,6 +355,10 @@ class OpenAIGPTPreTrainedModel(nn.Module): - a path or url to a pretrained model archive containing: . `openai_gpt_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . a series of NumPy files containing OpenAI TensorFlow trained weights + from_tf: should we load the weights from a locally saved TensorFlow checkpoint cache_dir: an optional path to a folder in which the pre-trained models will be cached. state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models *inputs, **kwargs: additional input for the specific Bert class @@ -365,24 +376,22 @@ class OpenAIGPTPreTrainedModel(nn.Module): "Model name '{}' was not found in model name list ({}). " "We assumed '{}' was a path or url but couldn't find any file " "associated to this path or url.".format( - pretrained_model_name, - ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), - archive_file)) + pretrained_model_name, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file + ) + ) return None if resolved_archive_file == archive_file: logger.info("loading archive file {}".format(archive_file)) else: - logger.info("loading archive file {} from cache at {}".format( - archive_file, resolved_archive_file)) + logger.info("loading archive file {} from cache at {}".format(archive_file, resolved_archive_file)) tempdir = None if os.path.isdir(resolved_archive_file): serialization_dir = resolved_archive_file else: # Extract archive to temp dir tempdir = tempfile.mkdtemp() - logger.info("extracting archive file {} to temp dir {}".format( - resolved_archive_file, tempdir)) - with tarfile.open(resolved_archive_file, 'r:gz') as archive: + logger.info("extracting archive file {} to temp dir {}".format(resolved_archive_file, tempdir)) + with tarfile.open(resolved_archive_file, "r:gz") as archive: archive.extractall(tempdir) serialization_dir = tempdir # Load config @@ -391,18 +400,24 @@ class OpenAIGPTPreTrainedModel(nn.Module): logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) - if state_dict is None: + if state_dict is None and not from_tf: weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path) + state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None) + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + if from_tf: + # Directly load from a TensorFlow checkpoint (stored as NumPy array) + return load_tf_weights_in_openai_gpt(model, serialization_dir) old_keys = [] new_keys = [] for key in state_dict.keys(): new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") if new_key: old_keys.append(key) new_keys.append(new_key) @@ -413,34 +428,36 @@ class OpenAIGPTPreTrainedModel(nn.Module): unexpected_keys = [] error_msgs = [] # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) + metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata - def load(module, prefix=''): + def load(module, prefix=""): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) module._load_from_state_dict( - state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + ) for name, child in module._modules.items(): if child is not None: - load(child, prefix + name + '.') - load(model.transformer if hasattr(model, 'transformer') else model, prefix='') + load(child, prefix + name + ".") + + load(model.transformer if hasattr(model, "transformer") else model, prefix="") if len(missing_keys) > 0: - logger.info("Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, missing_keys)) + logger.info( + "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys) + ) if len(unexpected_keys) > 0: - logger.info("Weights from pretrained model not used in {}: {}".format( - model.__class__.__name__, unexpected_keys)) + logger.info( + "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys) + ) if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) + ) # Add additional embeddings for special tokens if needed - if num_special_tokens != config.n_special: + if num_special_tokens is not None and num_special_tokens != config.n_special: model.set_num_special_tokens(num_special_tokens) - if tempdir: - # Clean up temp dir - shutil.rmtree(tempdir) return model @@ -495,6 +512,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): hidden_states = model(input_ids) ``` """ + def __init__(self, config): super(OpenAIGPTModel, self).__init__(config) total_embeddings_size = config.vocab_size + config.n_special + config.n_ctx @@ -516,8 +534,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): # Initialize all new embeddings (in particular the special tokens) self.init_weights(self.embed) # Copy word and positional embeddings from the previous weights - self.embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :] - self.embed.weight.data[-self.config.n_ctx:, :] = old_embed.weight.data[-self.config.n_ctx:, :] + self.embed.weight.data[: self.config.vocab_size, :] = old_embed.weight.data[: self.config.vocab_size, :] + self.embed.weight.data[-self.config.n_ctx :, :] = old_embed.weight.data[-self.config.n_ctx :, :] def forward(self, input_ids, position_ids=None, token_type_ids=None): if position_ids is None: @@ -544,6 +562,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): hidden_states = block(hidden_states) return hidden_states.view(*input_shape, hidden_states.size(-1)) + class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): """OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training"). @@ -602,6 +621,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): lm_logits = model(input_ids) ``` """ + def __init__(self, config): super(OpenAIGPTLMHeadModel, self).__init__(config) self.transformer = OpenAIGPTModel(config) @@ -622,6 +642,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): return loss return lm_logits + class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): """OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training"). @@ -653,7 +674,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): Inputs: `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with the word BPE token indices selected in the range [0, config.vocab_size[ - `multiple_choice_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] + `mc_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] with a value of 1 were the last hidden state is (usually the [CLS] token) and 0 otherwise. `position_ids`: an optional torch.LongTensor with the same shape as input_ids with the position indices (selected in the range [config.vocab_size + config.n_special, @@ -678,14 +699,15 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ```python # Already been converted into BPE token ids input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - multiple_choice_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + mc_token_mask = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) config = modeling_openai.OpenAIGPTConfig() model = modeling_openai.OpenAIGPTLMHeadModel(config) - lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask) + lm_logits, multiple_choice_logits = model(input_ids, mc_token_mask) ``` """ + def __init__(self, config): super(OpenAIGPTDoubleHeadsModel, self).__init__(config) self.transformer = OpenAIGPTModel(config) @@ -698,18 +720,17 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): self.transformer.set_num_special_tokens(num_special_tokens) self.lm_head.set_embeddings_weights(self.transformer.embed.weight) - def forward(self, input_ids, multiple_choice_token_mask, position_ids=None, token_type_ids=None, - lm_labels=None, multiple_choice_labels=None): + def forward(self, input_ids, mc_token_mask, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None): hidden_states = self.transformer(input_ids, position_ids, token_type_ids) lm_logits = self.lm_head(hidden_states) - multiple_choice_logits = self.multiple_choice_head(hidden_states, multiple_choice_token_mask) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_mask) losses = [] if lm_labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) losses.append(loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))) - if multiple_choice_labels is not None: + if mc_labels is not None: loss_fct = CrossEntropyLoss() - losses.append(loss_fct(multiple_choice_logits, multiple_choice_labels.view(-1))) + losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))) if losses: return losses - return lm_logits, multiple_choice_logits + return lm_logits, mc_logits diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py index ba8994fd8a..54c387c34b 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl.py @@ -37,6 +37,7 @@ from torch.nn.parameter import Parameter from .modeling import BertLayerNorm as LayerNorm from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits from .file_utils import cached_path +from .convert_transfo_xl_checkpoint_to_pytorch import load_tf_weights_in_transfo_xl logger = logging.getLogger(__name__) @@ -48,6 +49,7 @@ PRETRAINED_CONFIG_ARCHIVE_MAP = { } CONFIG_NAME = 'transfo_xl_config.json' WEIGHTS_NAME = 'pytorch_model.bin' +TF_WEIGHTS_NAME = 'model.ckpt' class TransfoXLConfig(object): """Configuration class to store the configuration of a `TransfoXLModel`. @@ -749,7 +751,7 @@ class TransfoXLPreTrainedModel(nn.Module): @classmethod def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None, - *inputs, **kwargs): + from_tf=False, *inputs, **kwargs): """ Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. @@ -761,6 +763,10 @@ class TransfoXLPreTrainedModel(nn.Module): - a path or url to a pretrained model archive containing: . `transfo_xl_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `model.chkpt` a TensorFlow checkpoint + from_tf: should we load the weights from a locally saved TensorFlow checkpoint cache_dir: an optional path to a folder in which the pre-trained models will be cached. state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models *inputs, **kwargs: additional input for the specific Bert class @@ -799,9 +805,12 @@ class TransfoXLPreTrainedModel(nn.Module): logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) - if state_dict is None: + if state_dict is None and not from_tf: state_dict = torch.load(resolved_archive_file) - + if from_tf: + # Directly load from a TensorFlow checkpoint + weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) + return load_tf_weights_in_transfo_xl(model, weights_path) missing_keys = [] unexpected_keys = [] error_msgs = [] diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index 1492075817..e5e4dbda39 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -130,6 +130,9 @@ class OpenAIGPTTokenizer(object): else: self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) + def __len__(self): + return len(self.encoder) + len(self.special_tokens) + def set_special_tokens(self, special_tokens): self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))