diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md deleted file mode 100644 index 124b4b32c2..0000000000 --- a/CONTRIBUTING.md +++ /dev/null @@ -1,31 +0,0 @@ -# How to Contribute - -BERT needs to maintain permanent compatibility with the pre-trained model files, -so we do not plan to make any major changes to this library (other than what was -promised in the README). However, we can accept small patches related to -re-factoring and documentation. To submit contributes, there are just a few -small guidelines you need to follow. - -## Contributor License Agreement - -Contributions to this project must be accompanied by a Contributor License -Agreement. You (or your employer) retain the copyright to your contribution; -this simply gives us permission to use and redistribute your contributions as -part of the project. Head over to to see -your current agreements on file or to sign a new one. - -You generally only need to submit a CLA once, so if you've already submitted one -(even if it was for a different project), you probably don't need to do it -again. - -## Code reviews - -All submissions, including submissions by project members, require review. We -use GitHub pull requests for this purpose. Consult -[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more -information on using pull requests. - -## Community Guidelines - -This project follows -[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). diff --git a/README.md b/README.md index cd4c7c97dd..3ed9fd6b76 100644 --- a/README.md +++ b/README.md @@ -8,29 +8,26 @@ This implementation can load any pre-trained TensorFlow checkpoint for BERT (in The code to use, in addition, [the Multilingual and Chinese models](https://github.com/google-research/bert/blob/master/multilingual.md) will be added later this week (it's actually just the tokenization code that needs to be updated). -## Loading a TensorFlow checkpoint (e.g. [Google's pre-trained models](https://github.com/google-research/bert#pre-trained-models)) +## Installation, requirements, test -You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script. +This code was tested on Python 3.5+. The requirements are: -This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`). +- PyTorch (>= 0.4.1) +- tqdm -You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too. +To install the dependencies: -To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch. +````bash +pip install -r ./requirements.txt +```` -Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model: +A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`). -```shell -export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 - -python convert_tf_checkpoint_to_pytorch.py \ - --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \ - --bert_config_file $BERT_BASE_DIR/bert_config.json \ - --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin +You can run the tests with the command: +```bash +python -m pytest -sv tests/ ``` -You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models). - ## PyTorch models for BERT We included three PyTorch models in this repository that you will find in [`modeling.py`](modeling.py): @@ -52,10 +49,15 @@ We detail them here. This model takes as inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`), and - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. +- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. This model outputs a tuple composed of: -- `all_encoder_layers`: a list of torch.FloatTensor of size [batch_size, sequence_length, hidden_size] which is a list of the full sequences of hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), and +- `encoded_layers`: controled by the value of the `output_encoded_layers` argument: + + . `output_all_encoded_layers=True`: outputs a list of the encoded-hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + . `output_all_encoded_layers=False`: outputs only the encoded-hidden-states corresponding to the last attention block, + - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated to the first character of the input (`CLF`) to train on the Next-Sentence task (see BERT's paper). An example on how to use this class is given in the `extract_features.py` script which can be used to extract the hidden states of the model for a given input. @@ -76,26 +78,30 @@ The token-level classifier takes as input the full sequence of the last hidden s An example on how to use this class is given in the `run_squad.py` script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task. -## Installation, requirements, test -This code was tested on Python 3.5+. The requirements are: +## Converting a TensorFlow checkpoint in a PyTorch checkpoint -- PyTorch (>= 0.4.1) -- tqdm +You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script. -To install the dependencies: +This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`). -````bash -pip install -r ./requirements.txt -```` +You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too. -A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`). +To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch. -You can run the tests with the command: -```bash -python -m pytest -sv tests/ +Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model: + +```shell +export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 + +python convert_tf_checkpoint_to_pytorch.py \ + --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \ + --bert_config_file $BERT_BASE_DIR/bert_config.json \ + --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin ``` +You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models). + ## Training on large batches: gradient accumulation, multi-GPU and distributed training BERT-base and BERT-large are respectively 110M and 340M parameters models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most case a batch size of 32). diff --git a/__init__.py b/__init__.py deleted file mode 100644 index effb57b1e8..0000000000 --- a/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/bin/pytorch_pretrained_bert b/bin/pytorch_pretrained_bert new file mode 100644 index 0000000000..eee2b4c250 --- /dev/null +++ b/bin/pytorch_pretrained_bert @@ -0,0 +1,2 @@ +#!/bin/sh +python -m pytorch_pretrained_bert "$@" \ No newline at end of file diff --git a/extract_features.py b/examples/extract_features.py similarity index 96% rename from extract_features.py rename to examples/extract_features.py index 6ad3a90e00..3ea1909bb3 100644 --- a/extract_features.py +++ b/examples/extract_features.py @@ -19,18 +19,17 @@ from __future__ import division from __future__ import print_function import argparse -import codecs import collections import logging import json import re import torch -from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler +from torch.utils.data import TensorDataset, DataLoader, SequentialSampler from torch.utils.data.distributed import DistributedSampler -import tokenization -from modeling import BertConfig, BertModel +from pytorch_pretrained_bert.tokenization import convert_to_unicode, BertTokenizer +from pytorch_pretrained_bert.modeling import BertConfig, BertModel logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', @@ -171,7 +170,7 @@ def read_examples(input_file): unique_id = 0 with open(input_file, "r") as reader: while True: - line = tokenization.convert_to_unicode(reader.readline()) + line = convert_to_unicode(reader.readline()) if not line: break line = line.strip() @@ -227,13 +226,13 @@ def main(): n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') - logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) + logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1))) layer_indexes = [int(x) for x in args.layers.split(",")] bert_config = BertConfig.from_json_file(args.bert_config_file) - tokenizer = tokenization.FullTokenizer( + tokenizer = BertTokenizer( vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) examples = read_examples(args.input_file) diff --git a/run_classifier.py b/examples/run_classifier.py similarity index 96% rename from run_classifier.py rename to examples/run_classifier.py index 2f58382ede..9543b95a94 100644 --- a/run_classifier.py +++ b/examples/run_classifier.py @@ -30,9 +30,9 @@ import torch from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler -import tokenization -from modeling import BertConfig, BertForSequenceClassification -from optimization import BERTAdam +from pytorch_pretrained_bert.tokenization import printable_text, convert_to_unicode, BertTokenizer +from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification +from pytorch_pretrained_bert.optimization import BERTAdam logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', @@ -122,9 +122,9 @@ class MrpcProcessor(DataProcessor): if i == 0: continue guid = "%s-%s" % (set_type, i) - text_a = tokenization.convert_to_unicode(line[3]) - text_b = tokenization.convert_to_unicode(line[4]) - label = tokenization.convert_to_unicode(line[0]) + text_a = convert_to_unicode(line[3]) + text_b = convert_to_unicode(line[4]) + label = convert_to_unicode(line[0]) examples.append( InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples @@ -154,14 +154,14 @@ class MnliProcessor(DataProcessor): for (i, line) in enumerate(lines): if i == 0: continue - guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) - text_a = tokenization.convert_to_unicode(line[8]) - text_b = tokenization.convert_to_unicode(line[9]) - label = tokenization.convert_to_unicode(line[-1]) + guid = "%s-%s" % (set_type, convert_to_unicode(line[0])) + text_a = convert_to_unicode(line[8]) + text_b = convert_to_unicode(line[9]) + label = convert_to_unicode(line[-1]) examples.append( InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples - + class ColaProcessor(DataProcessor): """Processor for the CoLA data set (GLUE version).""" @@ -185,8 +185,8 @@ class ColaProcessor(DataProcessor): examples = [] for (i, line) in enumerate(lines): guid = "%s-%s" % (set_type, i) - text_a = tokenization.convert_to_unicode(line[3]) - label = tokenization.convert_to_unicode(line[1]) + text_a = convert_to_unicode(line[3]) + label = convert_to_unicode(line[1]) examples.append( InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) return examples @@ -273,7 +273,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer logger.info("*** Example ***") logger.info("guid: %s" % (example.guid)) logger.info("tokens: %s" % " ".join( - [tokenization.printable_text(x) for x in tokens])) + [printable_text(x) for x in tokens])) logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) logger.info( @@ -281,11 +281,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer logger.info("label: %s (id = %d)" % (example.label, label_id)) features.append( - InputFeatures( - input_ids=input_ids, - input_mask=input_mask, - segment_ids=segment_ids, - label_id=label_id)) + InputFeatures(input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + label_id=label_id)) return features @@ -307,7 +306,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length): def accuracy(out, labels): outputs = np.argmax(out, axis=1) - return np.sum(outputs==labels) + return np.sum(outputs == labels) def copy_optimizer_params_to_model(named_params_model, named_params_optimizer): """ Utility function for optimize_on_cpu and 16-bits training. @@ -497,7 +496,7 @@ def main(): processor = processors[task_name]() label_list = processor.get_labels() - tokenizer = tokenization.FullTokenizer( + tokenizer = BertTokenizer( vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) train_examples = None diff --git a/run_squad.py b/examples/run_squad.py similarity index 97% rename from run_squad.py rename to examples/run_squad.py index 248b92c504..1011e836fd 100644 --- a/run_squad.py +++ b/examples/run_squad.py @@ -25,7 +25,6 @@ import json import math import os import random -import six from tqdm import tqdm, trange import numpy as np @@ -33,9 +32,9 @@ import torch from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler -import tokenization -from modeling import BertConfig, BertForQuestionAnswering -from optimization import BERTAdam +from pytorch_pretrained_bert.tokenization import printable_text, whitespace_tokenize, BasicTokenizer, BertTokenizer +from pytorch_pretrained_bert.modeling import BertConfig, BertForQuestionAnswering +from pytorch_pretrained_bert.optimization import BERTAdam logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', @@ -65,9 +64,9 @@ class SquadExample(object): def __repr__(self): s = "" - s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) + s += "qas_id: %s" % (printable_text(self.qas_id)) s += ", question_text: %s" % ( - tokenization.printable_text(self.question_text)) + printable_text(self.question_text)) s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) if self.start_position: s += ", start_position: %d" % (self.start_position) @@ -156,7 +155,7 @@ def read_squad_examples(input_file, is_training): # guaranteed to be preserved. actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]) cleaned_answer_text = " ".join( - tokenization.whitespace_tokenize(orig_answer_text)) + whitespace_tokenize(orig_answer_text)) if actual_text.find(cleaned_answer_text) == -1: logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text) @@ -290,11 +289,11 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, logger.info("example_index: %s" % (example_index)) logger.info("doc_span_index: %s" % (doc_span_index)) logger.info("tokens: %s" % " ".join( - [tokenization.printable_text(x) for x in tokens])) - logger.info("token_to_orig_map: %s" % " ".join( - ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) + [printable_text(x) for x in tokens])) + logger.info("token_to_orig_map: %s" % " ".join([ + "%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()])) logger.info("token_is_max_context: %s" % " ".join([ - "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) + "%d:%s" % (x, y) for (x, y) in token_is_max_context.items() ])) logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) logger.info( @@ -306,7 +305,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, logger.info("start_position: %d" % (start_position)) logger.info("end_position: %d" % (end_position)) logger.info( - "answer: %s" % (tokenization.printable_text(answer_text))) + "answer: %s" % (printable_text(answer_text))) features.append( InputFeatures( @@ -582,7 +581,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): # and `pred_text`, and check if they are the same length. If they are # NOT the same length, the heuristic has failed. If they are the same # length, we assume the characters are one-to-one aligned. - tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) + tokenizer = BasicTokenizer(do_lower_case=do_lower_case) tok_text = " ".join(tokenizer.tokenize(orig_text)) @@ -606,7 +605,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): # We then project the characters in `pred_text` back to `orig_text` using # the character-to-character alignment. tok_s_to_ns_map = {} - for (i, tok_index) in six.iteritems(tok_ns_to_s_map): + for (i, tok_index) in tok_ns_to_s_map.items(): tok_s_to_ns_map[tok_index] = i orig_start_position = None @@ -827,7 +826,7 @@ def main(): raise ValueError("Output directory () already exists and is not empty.") os.makedirs(args.output_dir, exist_ok=True) - tokenizer = tokenization.FullTokenizer( + tokenizer = BertTokenizer( vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) train_examples = None diff --git a/modeling.py b/modeling.py deleted file mode 100644 index 53243e5eb4..0000000000 --- a/modeling.py +++ /dev/null @@ -1,483 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch BERT model.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import copy -import json -import math -import six -import torch -import torch.nn as nn -from torch.nn import CrossEntropyLoss -from six import string_types - -def gelu(x): - """Implementation of the gelu activation function. - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - """ - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) - - -def swish(x): - return x * torch.sigmoid(x) - - -ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} - - -class BertConfig(object): - """Configuration class to store the configuration of a `BertModel`. - """ - def __init__(self, - vocab_size, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - initializer_range=0.02): - """Constructs BertConfig. - - Args: - vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. - hidden_size: Size of the encoder layers and the pooler layer. - num_hidden_layers: Number of hidden layers in the Transformer encoder. - num_attention_heads: Number of attention heads for each attention layer in - the Transformer encoder. - intermediate_size: The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder. - hidden_act: The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported. - hidden_dropout_prob: The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler. - attention_probs_dropout_prob: The dropout ratio for the attention - probabilities. - max_position_embeddings: The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048). - type_vocab_size: The vocabulary size of the `token_type_ids` passed into - `BertModel`. - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. - """ - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - - @classmethod - def from_dict(cls, json_object): - """Constructs a `BertConfig` from a Python dictionary of parameters.""" - config = BertConfig(vocab_size=None) - for (key, value) in six.iteritems(json_object): - config.__dict__[key] = value - return config - - @classmethod - def from_json_file(cls, json_file): - """Constructs a `BertConfig` from a json file of parameters.""" - with open(json_file, "r") as reader: - text = reader.read() - return cls.from_dict(json.loads(text)) - - def to_dict(self): - """Serializes this instance to a Python dictionary.""" - output = copy.deepcopy(self.__dict__) - return output - - def to_json_string(self): - """Serializes this instance to a JSON string.""" - return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" - - -class BERTLayerNorm(nn.Module): - def __init__(self, config, variance_epsilon=1e-12): - """Construct a layernorm module in the TF style (epsilon inside the square root). - """ - super(BERTLayerNorm, self).__init__() - self.gamma = nn.Parameter(torch.ones(config.hidden_size)) - self.beta = nn.Parameter(torch.zeros(config.hidden_size)) - self.variance_epsilon = variance_epsilon - - def forward(self, x): - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.gamma * x + self.beta - -class BERTEmbeddings(nn.Module): - def __init__(self, config): - super(BERTEmbeddings, self).__init__() - """Construct the embedding module from word, position and token_type embeddings. - """ - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = BERTLayerNorm(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, input_ids, token_type_ids=None): - seq_length = input_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = words_embeddings + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BERTSelfAttention(nn.Module): - def __init__(self, config): - super(BERTSelfAttention, self).__init__() - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads)) - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states, attention_mask): - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - return context_layer - - -class BERTSelfOutput(nn.Module): - def __init__(self, config): - super(BERTSelfOutput, self).__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = BERTLayerNorm(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BERTAttention(nn.Module): - def __init__(self, config): - super(BERTAttention, self).__init__() - self.self = BERTSelfAttention(config) - self.output = BERTSelfOutput(config) - - def forward(self, input_tensor, attention_mask): - self_output = self.self(input_tensor, attention_mask) - attention_output = self.output(self_output, input_tensor) - return attention_output - - -class BERTIntermediate(nn.Module): - def __init__(self, config): - super(BERTIntermediate, self).__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - self.intermediate_act_fn = ACT2FN[config.hidden_act] \ - if isinstance(config.hidden_act, string_types) else config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -class BERTOutput(nn.Module): - def __init__(self, config): - super(BERTOutput, self).__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = BERTLayerNorm(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BERTLayer(nn.Module): - def __init__(self, config): - super(BERTLayer, self).__init__() - self.attention = BERTAttention(config) - self.intermediate = BERTIntermediate(config) - self.output = BERTOutput(config) - - def forward(self, hidden_states, attention_mask): - attention_output = self.attention(hidden_states, attention_mask) - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output - - -class BERTEncoder(nn.Module): - def __init__(self, config): - super(BERTEncoder, self).__init__() - layer = BERTLayer(config) - self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) - - def forward(self, hidden_states, attention_mask): - all_encoder_layers = [] - for layer_module in self.layer: - hidden_states = layer_module(hidden_states, attention_mask) - all_encoder_layers.append(hidden_states) - return all_encoder_layers - - -class BERTPooler(nn.Module): - def __init__(self, config): - super(BERTPooler, self).__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class BertModel(nn.Module): - """BERT model ("Bidirectional Embedding Representations from a Transformer"). - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) - - config = modeling.BertConfig(vocab_size=32000, hidden_size=512, - num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) - - model = modeling.BertModel(config=config) - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config: BertConfig): - """Constructor for BertModel. - - Args: - config: `BertConfig` instance. - """ - super(BertModel, self).__init__() - self.embeddings = BERTEmbeddings(config) - self.encoder = BERTEncoder(config) - self.pooler = BERTPooler(config) - - def forward(self, input_ids, token_type_ids=None, attention_mask=None): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - embedding_output = self.embeddings(input_ids, token_type_ids) - all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) - sequence_output = all_encoder_layers[-1] - pooled_output = self.pooler(sequence_output) - return all_encoder_layers, pooled_output - -class BertForSequenceClassification(nn.Module): - """BERT model for classification. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) - - config = BertConfig(vocab_size=32000, hidden_size=512, - num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) - - num_labels = 2 - - model = BertForSequenceClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config, num_labels): - super(BertForSequenceClassification, self).__init__() - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - - def init_weights(module): - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=config.initializer_range) - elif isinstance(module, BERTLayerNorm): - module.beta.data.normal_(mean=0.0, std=config.initializer_range) - module.gamma.data.normal_(mean=0.0, std=config.initializer_range) - if isinstance(module, nn.Linear): - module.bias.data.zero_() - self.apply(init_weights) - - def forward(self, input_ids, token_type_ids, attention_mask, labels=None): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits, labels) - return loss, logits - else: - return logits - -class BertForQuestionAnswering(nn.Module): - """BERT model for Question Answering (span extraction). - This module is composed of the BERT model with a linear layer on top of - the sequence output that computes start_logits and end_logits - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) - - config = BertConfig(vocab_size=32000, hidden_size=512, - num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) - - model = BertForQuestionAnswering(config) - start_logits, end_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config): - super(BertForQuestionAnswering, self).__init__() - self.bert = BertModel(config) - # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version - # self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - - def init_weights(module): - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=config.initializer_range) - elif isinstance(module, BERTLayerNorm): - module.beta.data.normal_(mean=0.0, std=config.initializer_range) - module.gamma.data.normal_(mean=0.0, std=config.initializer_range) - if isinstance(module, nn.Linear): - module.bias.data.zero_() - self.apply(init_weights) - - def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): - all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) - sequence_output = all_encoder_layers[-1] - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions.clamp_(0, ignored_index) - end_positions.clamp_(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - return total_loss - else: - return start_logits, end_logits diff --git a/notebooks/Comparing TF and PT models SQuAD predictions.ipynb b/notebooks/Comparing TF and PT models SQuAD predictions.ipynb index d314eccfae..c91822d8f9 100644 --- a/notebooks/Comparing TF and PT models SQuAD predictions.ipynb +++ b/notebooks/Comparing TF and PT models SQuAD predictions.ipynb @@ -463,7 +463,7 @@ ], "source": [ "bert_config = modeling_tensorflow.BertConfig.from_json_file(bert_config_file)\n", - "tokenizer = tokenization.FullTokenizer(\n", + "tokenizer = tokenization.BertTokenizer(\n", " vocab_file=vocab_file, do_lower_case=True)\n", "\n", "eval_examples = read_squad_examples(\n", diff --git a/notebooks/Comparing TF and PT models.ipynb b/notebooks/Comparing TF and PT models.ipynb index 3623c08b5e..5e724a710a 100644 --- a/notebooks/Comparing TF and PT models.ipynb +++ b/notebooks/Comparing TF and PT models.ipynb @@ -22,8 +22,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:58:50.559657Z", - "start_time": "2018-11-05T13:58:50.546096Z" + "end_time": "2018-11-15T14:56:48.412622Z", + "start_time": "2018-11-15T14:56:48.400110Z" } }, "outputs": [], @@ -44,8 +44,8 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:58:50.574455Z", - "start_time": "2018-11-05T13:58:50.561988Z" + "end_time": "2018-11-15T14:56:49.483829Z", + "start_time": "2018-11-15T14:56:49.471296Z" } }, "outputs": [], @@ -63,19 +63,39 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:58:52.202531Z", - "start_time": "2018-11-05T13:58:50.576198Z" + "end_time": "2018-11-15T14:57:51.597932Z", + "start_time": "2018-11-15T14:57:51.549466Z" } }, - "outputs": [], + "outputs": [ + { + "ename": "DuplicateFlagError", + "evalue": "The flag 'input_file' is defined twice. First from *, Second from *. Description from first occurrence: (no help available)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mDuplicateFlagError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mspec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspec_from_file_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'*'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moriginal_tf_inplem_dir\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'/extract_features_tensorflow.py'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_from_spec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexec_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'extract_features_tensorflow'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/importlib/_bootstrap_external.py\u001b[0m in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n", + "\u001b[0;32m~/Documents/Thomas/Code/HF/BERT/pytorch-pretrained-BERT/tensorflow_code/extract_features_tensorflow.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0mFLAGS\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFLAGS\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDEFINE_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"input_file\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 35\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDEFINE_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"output_file\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/tensorflow/python/platform/flags.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;34m'Use of the keyword argument names (flag_name, default_value, '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m 'docstring) is deprecated, please use (name, default, help) instead.')\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0moriginal_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtf_decorator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake_decorator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moriginal_function\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_defines.py\u001b[0m in \u001b[0;36mDEFINE_string\u001b[0;34m(name, default, help, flag_values, **args)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0mparser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_argument_parser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArgumentParser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0mserializer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_argument_parser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArgumentSerializer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 241\u001b[0;31m \u001b[0mDEFINE\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparser\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflag_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mserializer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 242\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_defines.py\u001b[0m in \u001b[0;36mDEFINE\u001b[0;34m(parser, name, default, help, flag_values, serializer, module_name, **args)\u001b[0m\n\u001b[1;32m 80\u001b[0m \"\"\"\n\u001b[1;32m 81\u001b[0m DEFINE_flag(_flag.Flag(parser, serializer, name, default, help, **args),\n\u001b[0;32m---> 82\u001b[0;31m flag_values, module_name)\n\u001b[0m\u001b[1;32m 83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_defines.py\u001b[0m in \u001b[0;36mDEFINE_flag\u001b[0;34m(flag, flag_values, module_name)\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;31m# Copying the reference to flag_values prevents pychecker warnings.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0mfv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflag_values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 104\u001b[0;31m \u001b[0mfv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mflag\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflag\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 105\u001b[0m \u001b[0;31m# Tell flag_values who's defining the flag.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmodule_name\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_flagvalues.py\u001b[0m in \u001b[0;36m__setitem__\u001b[0;34m(self, name, flag)\u001b[0m\n\u001b[1;32m 427\u001b[0m \u001b[0;31m# module is simply being imported a subsequent time.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 428\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 429\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0m_exceptions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDuplicateFlagError\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_flag\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 430\u001b[0m \u001b[0mshort_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflag\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshort_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 431\u001b[0m \u001b[0;31m# If a new flag overrides an old one, we need to cleanup the old flag's\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDuplicateFlagError\u001b[0m: The flag 'input_file' is defined twice. First from *, Second from *. Description from first occurrence: (no help available)" + ] + } + ], "source": [ "import importlib.util\n", "import sys\n", "\n", - "spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/extract_features.py')\n", + "spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/extract_features_tensorflow.py')\n", "module = importlib.util.module_from_spec(spec)\n", "spec.loader.exec_module(module)\n", "sys.modules['extract_features_tensorflow'] = module\n", @@ -85,11 +105,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:58:52.325822Z", - "start_time": "2018-11-05T13:58:52.205361Z" + "end_time": "2018-11-15T14:58:05.650987Z", + "start_time": "2018-11-15T14:58:05.541620Z" } }, "outputs": [ @@ -122,11 +142,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:58:55.939938Z", - "start_time": "2018-11-05T13:58:52.330202Z" + "end_time": "2018-11-15T14:58:11.562443Z", + "start_time": "2018-11-15T14:58:08.036485Z" } }, "outputs": [ @@ -134,15 +154,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "WARNING:tensorflow:Estimator's model_fn (.model_fn at 0x12839dbf8>) includes params argument, but params are not passed to Estimator.\n", - "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u\n", - "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", + "WARNING:tensorflow:Estimator's model_fn (.model_fn at 0x11ea7f1e0>) includes params argument, but params are not passed to Estimator.\n", + "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9\n", + "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", - ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n", + ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n", "WARNING:tensorflow:Setting TPUConfig.num_shards==1 is an unsupported behavior. Please fix as soon as possible (leaving num_shards as None.\n", "INFO:tensorflow:_TPUContext: eval_on_tpu True\n", "WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.\n" @@ -178,11 +198,11 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:01.717585Z", - "start_time": "2018-11-05T13:58:55.941869Z" + "end_time": "2018-11-15T14:58:21.736543Z", + "start_time": "2018-11-15T14:58:16.723829Z" } }, "outputs": [ @@ -190,7 +210,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u, running initialization to predict.\n", + "INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9, running initialization to predict.\n", "INFO:tensorflow:Calling model_fn.\n", "INFO:tensorflow:Running infer on CPU\n", "INFO:tensorflow:Done calling model_fn.\n", @@ -241,11 +261,11 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:01.769845Z", - "start_time": "2018-11-05T13:59:01.719878Z" + "end_time": "2018-11-15T14:58:23.970714Z", + "start_time": "2018-11-15T14:58:23.931930Z" } }, "outputs": [ @@ -266,7 +286,7 @@ "(128, 768)" ] }, - "execution_count": 7, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -282,11 +302,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 12, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:01.807638Z", - "start_time": "2018-11-05T13:59:01.771422Z" + "end_time": "2018-11-15T14:58:25.547012Z", + "start_time": "2018-11-15T14:58:25.516076Z" } }, "outputs": [], @@ -303,361 +323,391 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.chdir('./examples')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:02.020918Z", - "start_time": "2018-11-05T13:59:01.810061Z" + "end_time": "2018-11-15T15:03:49.528679Z", + "start_time": "2018-11-15T15:03:49.497697Z" } }, "outputs": [], "source": [ "import extract_features\n", + "import pytorch_pretrained_bert as ppb\n", "from extract_features import *" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 25, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:02.058211Z", - "start_time": "2018-11-05T13:59:02.022785Z" + "end_time": "2018-11-15T15:21:18.001177Z", + "start_time": "2018-11-15T15:21:17.970369Z" } }, "outputs": [], "source": [ - "init_checkpoint_pt = \"../google_models/uncased_L-12_H-768_A-12/pytorch_model.bin\"" + "init_checkpoint_pt = \"../../google_models/uncased_L-12_H-768_A-12/\"" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 26, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:03.740561Z", - "start_time": "2018-11-05T13:59:02.059877Z" + "end_time": "2018-11-15T15:21:20.893669Z", + "start_time": "2018-11-15T15:21:18.786623Z" }, "scrolled": true }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling - loading archive file ../../google_models/uncased_L-12_H-768_A-12/\n", + "11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling - Model config {\n", + " \"attention_probs_dropout_prob\": 0.1,\n", + " \"hidden_act\": \"gelu\",\n", + " \"hidden_dropout_prob\": 0.1,\n", + " \"hidden_size\": 768,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 3072,\n", + " \"max_position_embeddings\": 512,\n", + " \"num_attention_heads\": 12,\n", + " \"num_hidden_layers\": 12,\n", + " \"type_vocab_size\": 2,\n", + " \"vocab_size\": 30522\n", + "}\n", + "\n" + ] + }, { "data": { "text/plain": [ "BertModel(\n", - " (embeddings): BERTEmbeddings(\n", + " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30522, 768)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (encoder): BERTEncoder(\n", + " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", - " (0): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (1): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (2): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (3): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (4): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (5): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (6): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (6): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (7): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (8): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (9): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (10): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (11): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", " )\n", " )\n", - " (pooler): BERTPooler(\n", + " (pooler): BertPooler(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", ")" ] }, - "execution_count": 11, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = torch.device(\"cpu\")\n", - "model = extract_features.BertModel(bert_config)\n", - "model.load_state_dict(torch.load(init_checkpoint_pt, map_location='cpu'))\n", + "model = ppb.BertModel.from_pretrained(init_checkpoint_pt)\n", "model.to(device)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 27, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:03.780145Z", - "start_time": "2018-11-05T13:59:03.742407Z" + "end_time": "2018-11-15T15:21:26.963427Z", + "start_time": "2018-11-15T15:21:26.922494Z" }, "code_folding": [] }, @@ -666,301 +716,301 @@ "data": { "text/plain": [ "BertModel(\n", - " (embeddings): BERTEmbeddings(\n", + " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30522, 768)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (encoder): BERTEncoder(\n", + " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", - " (0): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (1): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (2): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (3): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (4): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (5): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (6): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (6): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (7): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (8): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (9): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (10): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (11): BERTLayer(\n", - " (attention): BERTAttention(\n", - " (self): BERTSelfAttention(\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1)\n", " )\n", - " (output): BERTSelfOutput(\n", + " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", - " (intermediate): BERTIntermediate(\n", + " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", - " (output): BERTOutput(\n", + " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (LayerNorm): BERTLayerNorm()\n", + " (LayerNorm): BertLayerNorm()\n", " (dropout): Dropout(p=0.1)\n", " )\n", " )\n", " )\n", " )\n", - " (pooler): BERTPooler(\n", + " (pooler): BertPooler(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", ")" ] }, - "execution_count": 12, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -980,11 +1030,11 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 28, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:04.233844Z", - "start_time": "2018-11-05T13:59:03.782525Z" + "end_time": "2018-11-15T15:21:30.718724Z", + "start_time": "2018-11-15T15:21:30.329205Z" } }, "outputs": [ @@ -1068,11 +1118,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 29, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:04.278496Z", - "start_time": "2018-11-05T13:59:04.235703Z" + "end_time": "2018-11-15T15:21:35.703615Z", + "start_time": "2018-11-15T15:21:35.666150Z" } }, "outputs": [ @@ -1094,7 +1144,7 @@ "(128, 768)" ] }, - "execution_count": 14, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1111,11 +1161,11 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 30, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:04.313952Z", - "start_time": "2018-11-05T13:59:04.280352Z" + "end_time": "2018-11-15T15:21:36.999073Z", + "start_time": "2018-11-15T15:21:36.966762Z" } }, "outputs": [ @@ -1136,11 +1186,11 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 31, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:04.350048Z", - "start_time": "2018-11-05T13:59:04.316003Z" + "end_time": "2018-11-15T15:21:37.936522Z", + "start_time": "2018-11-15T15:21:37.905269Z" } }, "outputs": [ @@ -1167,11 +1217,11 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 32, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:04.382430Z", - "start_time": "2018-11-05T13:59:04.351550Z" + "end_time": "2018-11-15T15:21:39.437137Z", + "start_time": "2018-11-15T15:21:39.406150Z" } }, "outputs": [], @@ -1181,11 +1231,11 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 33, "metadata": { "ExecuteTime": { - "end_time": "2018-11-05T13:59:04.428334Z", - "start_time": "2018-11-05T13:59:04.386070Z" + "end_time": "2018-11-15T15:21:40.181870Z", + "start_time": "2018-11-15T15:21:40.137023Z" } }, "outputs": [ diff --git a/notebooks/Comparing TF and PT models_MLM_NSP.ipynb b/notebooks/Comparing TF and PT models_MLM_NSP.ipynb new file mode 100644 index 0000000000..7b226e8371 --- /dev/null +++ b/notebooks/Comparing TF and PT models_MLM_NSP.ipynb @@ -0,0 +1,1268 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Comparing TensorFlow (original) and PyTorch models\n", + "\n", + "You can use this small notebook to check the conversion of the model's weights from the TensorFlow model to the PyTorch model. In the following, we compare the weights of the last layer on a simple example (in `input.txt`) but both models returns all the hidden layers so you can check every stage of the model.\n", + "\n", + "To run this notebook, follow these instructions:\n", + "- make sure that your Python environment has both TensorFlow and PyTorch installed,\n", + "- download the original TensorFlow implementation,\n", + "- download a pre-trained TensorFlow model as indicaded in the TensorFlow implementation readme,\n", + "- run the script `convert_tf_checkpoint_to_pytorch.py` as indicated in the `README` to convert the pre-trained TensorFlow model to PyTorch.\n", + "\n", + "If needed change the relative paths indicated in this notebook (at the beggining of Sections 1 and 2) to point to the relevent models and code." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:58:50.559657Z", + "start_time": "2018-11-05T13:58:50.546096Z" + } + }, + "outputs": [], + "source": [ + "import os\n", + "os.chdir('../')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1/ TensorFlow code" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:58:50.574455Z", + "start_time": "2018-11-05T13:58:50.561988Z" + } + }, + "outputs": [], + "source": [ + "original_tf_inplem_dir = \"./tensorflow_code/\"\n", + "model_dir = \"../google_models/uncased_L-12_H-768_A-12/\"\n", + "\n", + "vocab_file = model_dir + \"vocab.txt\"\n", + "bert_config_file = model_dir + \"bert_config.json\"\n", + "init_checkpoint = model_dir + \"bert_model.ckpt\"\n", + "\n", + "input_file = \"./samples/input.txt\"\n", + "max_seq_length = 128" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:58:52.202531Z", + "start_time": "2018-11-05T13:58:50.576198Z" + } + }, + "outputs": [], + "source": [ + "import importlib.util\n", + "import sys\n", + "\n", + "spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/extract_features.py')\n", + "module = importlib.util.module_from_spec(spec)\n", + "spec.loader.exec_module(module)\n", + "sys.modules['extract_features_tensorflow'] = module\n", + "\n", + "from extract_features_tensorflow import *" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:58:52.325822Z", + "start_time": "2018-11-05T13:58:52.205361Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:*** Example ***\n", + "INFO:tensorflow:unique_id: 0\n", + "INFO:tensorflow:tokens: [CLS] who was jim henson ? [SEP] jim henson was a puppet ##eer [SEP]\n", + "INFO:tensorflow:input_ids: 101 2040 2001 3958 27227 1029 102 3958 27227 2001 1037 13997 11510 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + "INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + "INFO:tensorflow:input_type_ids: 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n" + ] + } + ], + "source": [ + "layer_indexes = list(range(12))\n", + "bert_config = modeling.BertConfig.from_json_file(bert_config_file)\n", + "tokenizer = tokenization.BertTokenizer(\n", + " vocab_file=vocab_file, do_lower_case=True)\n", + "examples = read_examples(input_file)\n", + "\n", + "features = convert_examples_to_features(\n", + " examples=examples, seq_length=max_seq_length, tokenizer=tokenizer)\n", + "unique_id_to_feature = {}\n", + "for feature in features:\n", + " unique_id_to_feature[feature.unique_id] = feature" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:58:55.939938Z", + "start_time": "2018-11-05T13:58:52.330202Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:tensorflow:Estimator's model_fn (.model_fn at 0x12839dbf8>) includes params argument, but params are not passed to Estimator.\n", + "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u\n", + "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", + "graph_options {\n", + " rewrite_options {\n", + " meta_optimizer_iterations: ONE\n", + " }\n", + "}\n", + ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n", + "WARNING:tensorflow:Setting TPUConfig.num_shards==1 is an unsupported behavior. Please fix as soon as possible (leaving num_shards as None.\n", + "INFO:tensorflow:_TPUContext: eval_on_tpu True\n", + "WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.\n" + ] + } + ], + "source": [ + "is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2\n", + "run_config = tf.contrib.tpu.RunConfig(\n", + " master=None,\n", + " tpu_config=tf.contrib.tpu.TPUConfig(\n", + " num_shards=1,\n", + " per_host_input_for_training=is_per_host))\n", + "\n", + "model_fn = model_fn_builder(\n", + " bert_config=bert_config,\n", + " init_checkpoint=init_checkpoint,\n", + " layer_indexes=layer_indexes,\n", + " use_tpu=False,\n", + " use_one_hot_embeddings=False)\n", + "\n", + "# If TPU is not available, this will fall back to normal Estimator on CPU\n", + "# or GPU.\n", + "estimator = tf.contrib.tpu.TPUEstimator(\n", + " use_tpu=False,\n", + " model_fn=model_fn,\n", + " config=run_config,\n", + " predict_batch_size=1)\n", + "\n", + "input_fn = input_fn_builder(\n", + " features=features, seq_length=max_seq_length)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:01.717585Z", + "start_time": "2018-11-05T13:58:55.941869Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u, running initialization to predict.\n", + "INFO:tensorflow:Calling model_fn.\n", + "INFO:tensorflow:Running infer on CPU\n", + "INFO:tensorflow:Done calling model_fn.\n", + "INFO:tensorflow:Graph was finalized.\n", + "INFO:tensorflow:Running local_init_op.\n", + "INFO:tensorflow:Done running local_init_op.\n", + "extracting layer 0\n", + "extracting layer 1\n", + "extracting layer 2\n", + "extracting layer 3\n", + "extracting layer 4\n", + "extracting layer 5\n", + "extracting layer 6\n", + "extracting layer 7\n", + "extracting layer 8\n", + "extracting layer 9\n", + "extracting layer 10\n", + "extracting layer 11\n", + "INFO:tensorflow:prediction_loop marked as finished\n", + "INFO:tensorflow:prediction_loop marked as finished\n" + ] + } + ], + "source": [ + "tensorflow_all_out = []\n", + "for result in estimator.predict(input_fn, yield_single_examples=True):\n", + " unique_id = int(result[\"unique_id\"])\n", + " feature = unique_id_to_feature[unique_id]\n", + " output_json = collections.OrderedDict()\n", + " output_json[\"linex_index\"] = unique_id\n", + " tensorflow_all_out_features = []\n", + " # for (i, token) in enumerate(feature.tokens):\n", + " all_layers = []\n", + " for (j, layer_index) in enumerate(layer_indexes):\n", + " print(\"extracting layer {}\".format(j))\n", + " layer_output = result[\"layer_output_%d\" % j]\n", + " layers = collections.OrderedDict()\n", + " layers[\"index\"] = layer_index\n", + " layers[\"values\"] = layer_output\n", + " all_layers.append(layers)\n", + " tensorflow_out_features = collections.OrderedDict()\n", + " tensorflow_out_features[\"layers\"] = all_layers\n", + " tensorflow_all_out_features.append(tensorflow_out_features)\n", + "\n", + " output_json[\"features\"] = tensorflow_all_out_features\n", + " tensorflow_all_out.append(output_json)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:01.769845Z", + "start_time": "2018-11-05T13:59:01.719878Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n", + "2\n", + "odict_keys(['linex_index', 'features'])\n", + "number of tokens 1\n", + "number of layers 12\n" + ] + }, + { + "data": { + "text/plain": [ + "(128, 768)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(len(tensorflow_all_out))\n", + "print(len(tensorflow_all_out[0]))\n", + "print(tensorflow_all_out[0].keys())\n", + "print(\"number of tokens\", len(tensorflow_all_out[0]['features']))\n", + "print(\"number of layers\", len(tensorflow_all_out[0]['features'][0]['layers']))\n", + "tensorflow_all_out[0]['features'][0]['layers'][0]['values'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:01.807638Z", + "start_time": "2018-11-05T13:59:01.771422Z" + } + }, + "outputs": [], + "source": [ + "tensorflow_outputs = list(tensorflow_all_out[0]['features'][0]['layers'][t]['values'] for t in layer_indexes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2/ PyTorch code" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:02.020918Z", + "start_time": "2018-11-05T13:59:01.810061Z" + } + }, + "outputs": [], + "source": [ + "import extract_features\n", + "from extract_features import *" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:02.058211Z", + "start_time": "2018-11-05T13:59:02.022785Z" + } + }, + "outputs": [], + "source": [ + "init_checkpoint_pt = \"../google_models/uncased_L-12_H-768_A-12/pytorch_model.bin\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:03.740561Z", + "start_time": "2018-11-05T13:59:02.059877Z" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "BertModel(\n", + " (embeddings): BERTEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (encoder): BERTEncoder(\n", + " (layer): ModuleList(\n", + " (0): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (1): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (2): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (3): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (4): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (5): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (6): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (7): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (8): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (9): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (10): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (11): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BERTPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + ")" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = torch.device(\"cpu\")\n", + "model = extract_features.BertModel(bert_config)\n", + "model.load_state_dict(torch.load(init_checkpoint_pt, map_location='cpu'))\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:03.780145Z", + "start_time": "2018-11-05T13:59:03.742407Z" + }, + "code_folding": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "BertModel(\n", + " (embeddings): BERTEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (encoder): BERTEncoder(\n", + " (layer): ModuleList(\n", + " (0): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (1): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (2): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (3): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (4): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (5): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (6): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (7): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (8): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (9): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (10): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (11): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BERTPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)\n", + "all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)\n", + "all_input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long)\n", + "all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)\n", + "\n", + "eval_data = TensorDataset(all_input_ids, all_input_mask, all_input_type_ids, all_example_index)\n", + "eval_sampler = SequentialSampler(eval_data)\n", + "eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)\n", + "\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:04.233844Z", + "start_time": "2018-11-05T13:59:03.782525Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 101, 2040, 2001, 3958, 27227, 1029, 102, 3958, 27227, 2001,\n", + " 1037, 13997, 11510, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0]])\n", + "tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0]])\n", + "tensor([0])\n", + "layer 0 0\n", + "layer 1 1\n", + "layer 2 2\n", + "layer 3 3\n", + "layer 4 4\n", + "layer 5 5\n", + "layer 6 6\n", + "layer 7 7\n", + "layer 8 8\n", + "layer 9 9\n", + "layer 10 10\n", + "layer 11 11\n" + ] + } + ], + "source": [ + "layer_indexes = list(range(12))\n", + "\n", + "pytorch_all_out = []\n", + "for input_ids, input_mask, input_type_ids, example_indices in eval_dataloader:\n", + " print(input_ids)\n", + " print(input_mask)\n", + " print(example_indices)\n", + " input_ids = input_ids.to(device)\n", + " input_mask = input_mask.to(device)\n", + "\n", + " all_encoder_layers, _ = model(input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)\n", + "\n", + " for b, example_index in enumerate(example_indices):\n", + " feature = features[example_index.item()]\n", + " unique_id = int(feature.unique_id)\n", + " # feature = unique_id_to_feature[unique_id]\n", + " output_json = collections.OrderedDict()\n", + " output_json[\"linex_index\"] = unique_id\n", + " all_out_features = []\n", + " # for (i, token) in enumerate(feature.tokens):\n", + " all_layers = []\n", + " for (j, layer_index) in enumerate(layer_indexes):\n", + " print(\"layer\", j, layer_index)\n", + " layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()\n", + " layer_output = layer_output[b]\n", + " layers = collections.OrderedDict()\n", + " layers[\"index\"] = layer_index\n", + " layer_output = layer_output\n", + " layers[\"values\"] = layer_output if not isinstance(layer_output, (int, float)) else [layer_output]\n", + " all_layers.append(layers)\n", + "\n", + " out_features = collections.OrderedDict()\n", + " out_features[\"layers\"] = all_layers\n", + " all_out_features.append(out_features)\n", + " output_json[\"features\"] = all_out_features\n", + " pytorch_all_out.append(output_json)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:04.278496Z", + "start_time": "2018-11-05T13:59:04.235703Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n", + "2\n", + "odict_keys(['linex_index', 'features'])\n", + "number of tokens 1\n", + "number of layers 12\n", + "hidden_size 128\n" + ] + }, + { + "data": { + "text/plain": [ + "(128, 768)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(len(pytorch_all_out))\n", + "print(len(pytorch_all_out[0]))\n", + "print(pytorch_all_out[0].keys())\n", + "print(\"number of tokens\", len(pytorch_all_out))\n", + "print(\"number of layers\", len(pytorch_all_out[0]['features'][0]['layers']))\n", + "print(\"hidden_size\", len(pytorch_all_out[0]['features'][0]['layers'][0]['values']))\n", + "pytorch_all_out[0]['features'][0]['layers'][0]['values'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:04.313952Z", + "start_time": "2018-11-05T13:59:04.280352Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(128, 768)\n", + "(128, 768)\n" + ] + } + ], + "source": [ + "pytorch_outputs = list(pytorch_all_out[0]['features'][0]['layers'][t]['values'] for t in layer_indexes)\n", + "print(pytorch_outputs[0].shape)\n", + "print(pytorch_outputs[1].shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:04.350048Z", + "start_time": "2018-11-05T13:59:04.316003Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(128, 768)\n", + "(128, 768)\n" + ] + } + ], + "source": [ + "print(tensorflow_outputs[0].shape)\n", + "print(tensorflow_outputs[1].shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3/ Comparing the standard deviation on the last layer of both models" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:04.382430Z", + "start_time": "2018-11-05T13:59:04.351550Z" + } + }, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-05T13:59:04.428334Z", + "start_time": "2018-11-05T13:59:04.386070Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape tensorflow layer, shape pytorch layer, standard deviation\n", + "((128, 768), (128, 768), 1.5258875e-07)\n", + "((128, 768), (128, 768), 2.342731e-07)\n", + "((128, 768), (128, 768), 2.801949e-07)\n", + "((128, 768), (128, 768), 3.5904986e-07)\n", + "((128, 768), (128, 768), 4.2842768e-07)\n", + "((128, 768), (128, 768), 5.127951e-07)\n", + "((128, 768), (128, 768), 6.14668e-07)\n", + "((128, 768), (128, 768), 7.063922e-07)\n", + "((128, 768), (128, 768), 7.906173e-07)\n", + "((128, 768), (128, 768), 8.475192e-07)\n", + "((128, 768), (128, 768), 8.975489e-07)\n", + "((128, 768), (128, 768), 4.1671223e-07)\n" + ] + } + ], + "source": [ + "print('shape tensorflow layer, shape pytorch layer, standard deviation')\n", + "print('\\n'.join(list(str((np.array(tensorflow_outputs[i]).shape,\n", + " np.array(pytorch_outputs[i]).shape, \n", + " np.sqrt(np.mean((np.array(tensorflow_outputs[i]) - np.array(pytorch_outputs[i]))**2.0)))) for i in range(12))))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "hide_input": false, + "kernelspec": { + "display_name": "Python [default]", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.7" + }, + "toc": { + "colors": { + "hover_highlight": "#DAA520", + "running_highlight": "#FF0000", + "selected_highlight": "#FFD700" + }, + "moveMenuLeft": true, + "nav_menu": { + "height": "48px", + "width": "252px" + }, + "navigate_menu": true, + "number_sections": true, + "sideBar": true, + "threshold": 4, + "toc_cell": false, + "toc_section_display": "block", + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py new file mode 100644 index 0000000000..066bd7830e --- /dev/null +++ b/pytorch_pretrained_bert/__init__.py @@ -0,0 +1,5 @@ +from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer +from .modeling import (BertConfig, BertModel, BertForPreTraining, + BertForMaskedLM, BertForNextSentencePrediction, + BertForSequenceClassification, BertForQuestionAnswering) +from .optimization import BERTAdam diff --git a/pytorch_pretrained_bert/__main__.py b/pytorch_pretrained_bert/__main__.py new file mode 100644 index 0000000000..73f1909b43 --- /dev/null +++ b/pytorch_pretrained_bert/__main__.py @@ -0,0 +1,19 @@ +# coding: utf8 +if __name__ == '__main__': + import sys + try: + from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch + except ModuleNotFoundError: + print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " + "In that case, it requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions.") + raise + + if len(sys.argv) != 5: + # pylint: disable=line-too-long + print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") + else: + PYTORCH_DUMP_OUTPUT = sys.argv.pop() + TF_CONFIG = sys.argv.pop() + TF_CHECKPOINT = sys.argv.pop() + convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) diff --git a/convert_tf_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py similarity index 50% rename from convert_tf_checkpoint_to_pytorch.py rename to pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py index eeebb3728e..20fdd8c0d6 100755 --- a/convert_tf_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py @@ -18,66 +18,39 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import re import argparse import tensorflow as tf import torch import numpy as np -from modeling import BertConfig, BertModel - -parser = argparse.ArgumentParser() - -## Required parameters -parser.add_argument("--tf_checkpoint_path", - default = None, - type = str, - required = True, - help = "Path the TensorFlow checkpoint path.") -parser.add_argument("--bert_config_file", - default = None, - type = str, - required = True, - help = "The config json file corresponding to the pre-trained BERT model. \n" - "This specifies the model architecture.") -parser.add_argument("--pytorch_dump_path", - default = None, - type = str, - required = True, - help = "Path to the output PyTorch model.") - -args = parser.parse_args() - -def convert(): - # Initialise PyTorch model - config = BertConfig.from_json_file(args.bert_config_file) - model = BertModel(config) +from .modeling import BertConfig, BertForPreTraining +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): + config_path = os.path.abspath(bert_config_file) + tf_path = os.path.abspath(tf_checkpoint_path) + print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path)) # Load weights from TF model - path = args.tf_checkpoint_path - print("Converting TensorFlow checkpoint from {}".format(path)) - - init_vars = tf.train.list_variables(path) + init_vars = tf.train.list_variables(tf_path) names = [] arrays = [] for name, shape in init_vars: - print("Loading {} with shape {}".format(name, shape)) - array = tf.train.load_variable(path, name) - print("Numpy array shape {}".format(array.shape)) + print("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) names.append(name) arrays.append(array) + # Initialise PyTorch model + config = BertConfig.from_json_file(bert_config_file) + print("Building PyTorch model from configuration: {}".format(str(config))) + model = BertForPreTraining(config) + for name, array in zip(names, arrays): - if not name.startswith("bert"): - print("Skipping {}".format(name)) - continue - else: - name = name.replace("bert/", "") # skip "bert/" - print("Loading {}".format(name)) name = name.split('/') # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model - if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m": + if name[-1] in ["adam_v", "adam_m"]: print("Skipping {}".format("/".join(name))) continue pointer = model @@ -88,6 +61,10 @@ def convert(): l = [m_name] if l[0] == 'kernel': pointer = getattr(pointer, 'weight') + elif l[0] == 'output_bias': + pointer = getattr(pointer, 'bias') + elif l[0] == 'output_weights': + pointer = getattr(pointer, 'weight') else: pointer = getattr(pointer, l[0]) if len(l) >= 2: @@ -102,10 +79,34 @@ def convert(): except AssertionError as e: e.args += (pointer.shape, array.shape) raise + print("Initialize PyTorch weight {}".format(name)) pointer.data = torch.from_numpy(array) # Save pytorch-model - torch.save(model.state_dict(), args.pytorch_dump_path) + print("Save PyTorch model to {}".format(pytorch_dump_path)) + torch.save(model.state_dict(), pytorch_dump_path) + if __name__ == "__main__": - convert() + parser = argparse.ArgumentParser() + ## Required parameters + parser.add_argument("--tf_checkpoint_path", + default = None, + type = str, + required = True, + help = "Path the TensorFlow checkpoint path.") + parser.add_argument("--bert_config_file", + default = None, + type = str, + required = True, + help = "The config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture.") + parser.add_argument("--pytorch_dump_path", + default = None, + type = str, + required = True, + help = "Path to the output PyTorch model.") + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, + args.bert_config_file, + args.pytorch_dump_path) diff --git a/pytorch_pretrained_bert/file_utils.py b/pytorch_pretrained_bert/file_utils.py new file mode 100644 index 0000000000..f734b7e22b --- /dev/null +++ b/pytorch_pretrained_bert/file_utils.py @@ -0,0 +1,233 @@ +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" + +import os +import logging +import shutil +import tempfile +import json +from urllib.parse import urlparse +from pathlib import Path +from typing import Optional, Tuple, Union, IO, Callable, Set +from hashlib import sha256 +from functools import wraps + +from tqdm import tqdm + +import boto3 +from botocore.exceptions import ClientError +import requests + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + Path.home() / '.pytorch_pretrained_bert')) + + +def url_to_filename(url: str, etag: str = None) -> str: + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]: + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise FileNotFoundError("file {} not found".format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise FileNotFoundError("file {} not found".format(meta_path)) + + with open(meta_path) as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str: + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. + raise FileNotFoundError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + +def split_s3_path(url: str) -> Tuple[str, str]: + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func: Callable): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url: str, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise FileNotFoundError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url: str) -> Optional[str]: + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url: str, temp_file: IO) -> None: + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url: str, temp_file: IO) -> None: + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url: str, cache_dir: str = None) -> str: + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + + os.makedirs(cache_dir, exist_ok=True) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + raise IOError("HEAD request failed for url {} with status code {}" + .format(url, response.status_code)) + etag = response.headers.get("ETag") + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w') as meta_file: + json.dump(meta, meta_file) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename: str) -> Set[str]: + ''' + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + ''' + collection = set() + with open(filename, 'r') as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path: str, dot=True, lower: bool = True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py new file mode 100644 index 0000000000..9ef592d0dc --- /dev/null +++ b/pytorch_pretrained_bert/modeling.py @@ -0,0 +1,964 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import copy +import json +import math +import logging +import tarfile +import tempfile +import shutil + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from .file_utils import cached_path + +logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.INFO) +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", + 'bert-base-multilingual': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual.tar.gz", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", +} +CONFIG_NAME = 'bert_config.json' +WEIGHTS_NAME = 'pytorch_model.bin' + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. + """ + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02): + """Constructs BertConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + if isinstance(vocab_size_or_config_json_file, str): + with open(vocab_size_or_config_json_file, "r") as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + else: + raise ValueError("First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)") + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r") as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + +class BertLayerNorm(nn.Module): + def __init__(self, config, variance_epsilon=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.ones(config.hidden_size)) + self.beta = nn.Parameter(torch.zeros(config.hidden_size)) + self.variance_epsilon = variance_epsilon + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.gamma * x + self.beta + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super(BertEncoder, self).__init__() + layer = BertLayer(config) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.transform_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + self.LayerNorm = BertLayerNorm(config) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(bert_model_embedding_weights.size(1), + bert_model_embedding_weights.size(0), + bias=False) + self.decoder.weight = bert_model_embedding_weights + self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super(BertOnlyNSPHead, self).__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class PreTrainedBertModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super(PreTrainedBertModel, self).__init__() + if not isinstance(config, BertConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.beta.data.normal_(mean=0.0, std=self.config.initializer_range) + module.gamma.data.normal_(mean=0.0, std=self.config.initializer_range) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, pretrained_model_name, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-base-multilingual` + . `bert-base-chinese` + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] + else: + archive_file = pretrained_model_name + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file) + except FileNotFoundError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name, + ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), + pretrained_model_name)) + return None + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info("loading archive file {} from cache at {}".format( + archive_file, resolved_archive_file)) + tempdir = None + if os.path.isdir(resolved_archive_file): + serialization_dir = resolved_archive_file + else: + # Extract archive to temp dir + tempdir = tempfile.mkdtemp() + logger.info("extracting archive file {} to temp dir {}".format( + resolved_archive_file, tempdir)) + with tarfile.open(resolved_archive_file, 'r:gz') as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + # Load config + config_file = os.path.join(serialization_dir, CONFIG_NAME) + config = BertConfig.from_json_file(config_file) + logger.info("Model config {}".format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(weights_path) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + load(model, prefix='' if hasattr(model, 'bert') else 'bert.') + if len(missing_keys) > 0: + logger.info("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + logger.info("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys)) + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + return model + + +class BertModel(PreTrainedBertModel): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Params: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block, + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLF`) to train on the Next-Sentence task (see BERT's paper). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) + + config = modeling.BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertModel, self).__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder(embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + +class BertForPreTraining(PreTrainedBertModel): + """BERT model with pre-training heads. + This module comprises the BERT model followed by the two pre-training heads: + - the masked language modeling head, and + - the next sentence classification head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] + with indices selected in [0, 1]. + 0 => next sentence is the continuation, 1 => next sentence is a random sentence. + + Outputs: + if `masked_lm_labels` and `next_sentence_label` are not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `masked_lm_labels` or `next_sentence_label` is `None`: + Outputs a tuple comprising + - the masked language modeling logits, and + - the next sentence classification logits. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) + + config = BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) + + model = BertForPreTraining(config) + masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForPreTraining, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None): + sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False) + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + if masked_lm_labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels) + next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label) + total_loss = masked_lm_loss + next_sentence_loss + return total_loss + else: + return prediction_scores, seq_relationship_score + + +class BertForMaskedLM(PreTrainedBertModel): + """BERT model with the masked language modeling head. + This module comprises the BERT model followed by the masked language modeling head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + + Outputs: + if `masked_lm_labels` is `None`: + Outputs the masked language modeling loss. + if `masked_lm_labels` is `None`: + Outputs the masked language modeling logits. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) + + config = BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) + + model = BertForMaskedLM(config) + masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForMaskedLM, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False) + prediction_scores = self.cls(sequence_output) + + if masked_lm_labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels) + return masked_lm_loss + else: + return prediction_scores + + +class BertForNextSentencePrediction(PreTrainedBertModel): + """BERT model with next sentence prediction head. + This module comprises the BERT model followed by the next sentence classification head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] + with indices selected in [0, 1]. + 0 => next sentence is the continuation, 1 => next sentence is a random sentence. + + Outputs: + if `next_sentence_label` is not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `next_sentence_label` is `None`: + Outputs the next sentence classification logits. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) + + config = BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) + + model = BertForNextSentencePrediction(config) + seq_relationship_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForNextSentencePrediction, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None): + _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False) + seq_relationship_score = self.cls( pooled_output) + + if next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label) + return next_sentence_loss + else: + return seq_relationship_score + + +class BertForSequenceClassification(PreTrainedBertModel): + """BERT model for classification. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_labels]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) + + config = BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) + + num_labels = 2 + + model = BertForSequenceClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_labels=2): + super(BertForSequenceClassification, self).__init__(config) + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, num_labels) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): + _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits, labels) + return loss, logits + else: + return logits + + +class BertForQuestionAnswering(PreTrainedBertModel): + """BERT model for Question Answering (span extraction). + This module is composed of the BERT model with a linear layer on top of + the sequence output that computes start_logits and end_logits + + Params: + `config`: either + - a BertConfig class instance with the configuration to build a new model, or + - a str with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-base-multilingual` + . `bert-base-chinese` + The pre-trained model will be downloaded and cached if needed. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + + Outputs: + if `start_positions` and `end_positions` are not `None`: + Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. + if `start_positions` or `end_positions` is `None`: + Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end + position tokens. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) + + config = BertConfig(vocab_size=32000, hidden_size=512, + num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) + + model = BertForQuestionAnswering(config) + start_logits, end_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForQuestionAnswering, self).__init__(config) + self.bert = BertModel(config) + # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version + # self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + return total_loss + else: + return start_logits, end_logits diff --git a/optimization.py b/pytorch_pretrained_bert/optimization.py similarity index 100% rename from optimization.py rename to pytorch_pretrained_bert/optimization.py diff --git a/tokenization.py b/pytorch_pretrained_bert/tokenization.py similarity index 73% rename from tokenization.py rename to pytorch_pretrained_bert/tokenization.py index 8cf83720d9..fab7b0049c 100644 --- a/tokenization.py +++ b/pytorch_pretrained_bert/tokenization.py @@ -20,27 +20,32 @@ from __future__ import print_function import collections import unicodedata -import six +import os +import logging +from .file_utils import cached_path + +logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.INFO) +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", + 'bert-base-multilingual': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-vocab.txt", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", +} def convert_to_unicode(text): """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode("utf-8", "ignore") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text.decode("utf-8", "ignore") - elif isinstance(text, unicode): - return text - else: - raise ValueError("Unsupported string type: %s" % (type(text))) + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") else: - raise ValueError("Not running on Python2 or Python 3?") + raise ValueError("Unsupported string type: %s" % (type(text))) def printable_text(text): @@ -48,22 +53,12 @@ def printable_text(text): # These functions want `str` for both Python2 and Python3, but in one case # it's a Unicode string and in the other it's a byte string. - if six.PY3: - if isinstance(text, str): - return text - elif isinstance(text, bytes): - return text.decode("utf-8", "ignore") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) - elif six.PY2: - if isinstance(text, str): - return text - elif isinstance(text, unicode): - return text.encode("utf-8") - else: - raise ValueError("Unsupported string type: %s" % (type(text))) + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") else: - raise ValueError("Not running on Python2 or Python 3?") + raise ValueError("Unsupported string type: %s" % (type(text))) def load_vocab(vocab_file): @@ -81,14 +76,6 @@ def load_vocab(vocab_file): return vocab -def convert_tokens_to_ids(vocab, tokens): - """Converts a sequence of tokens into ids using the vocab.""" - ids = [] - for token in tokens: - ids.append(vocab[token]) - return ids - - def whitespace_tokenize(text): """Runs basic whitespace cleaning and splitting on a peice of text.""" text = text.strip() @@ -98,11 +85,16 @@ def whitespace_tokenize(text): return tokens -class FullTokenizer(object): - """Runs end-to-end tokenziation.""" - +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" def __init__(self, vocab_file, do_lower_case=True): + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) @@ -111,11 +103,52 @@ class FullTokenizer(object): for token in self.basic_tokenizer.tokenize(text): for sub_token in self.wordpiece_tokenizer.tokenize(token): split_tokens.append(sub_token) - return split_tokens def convert_tokens_to_ids(self, tokens): - return convert_tokens_to_ids(self.vocab, tokens) + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + @classmethod + def from_pretrained(cls, pretrained_model_name, do_lower_case=True): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name] + else: + vocab_file = pretrained_model_name + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file) + if resolved_vocab_file == vocab_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, do_lower_case) + except FileNotFoundError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + pretrained_model_name)) + tokenizer = None + return tokenizer class BasicTokenizer(object): diff --git a/requirements.txt b/requirements.txt index 7d8fa561a2..0aeac31a31 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,9 @@ -torch -tqdm \ No newline at end of file +# This installs Pytorch for CUDA 8 only. If you are using a newer version, +# please visit http://pytorch.org/ and install the relevant version. +torch>=0.4.1,<0.5.0 +# progress bars in model download and training scripts +tqdm>=4.19 +# Accessing files from S3 directly. +boto3 +# Used for downloading models over HTTP +requests>=2.18 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000..642d8e1c8c --- /dev/null +++ b/setup.py @@ -0,0 +1,31 @@ +from setuptools import find_packages, setup + +setup( + name="pytorch_pretrained_bert", + version="0.1.0", + author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors", + author_email="thomas@huggingface.co", + description="PyTorch version of Google AI BERT model with script to load Google pre-trained models", + long_description=open("README.md", "r").read(), + long_description_content_type="text/markdown", + keywords='BERT NLP deep learning google', + license='Apache', + url="https://github.com/huggingface/pytorch-pretrained-BERT", + packages=find_packages(exclude=["*.tests", "*.tests.*", + "tests.*", "tests"]), + install_requires=['numpy', + 'torch>=0.4.1', + 'boto3', + 'requests>=2.18', + 'tqdm>=4.19'], + scripts=["bin/pytorch_pretrained_bert"], + python_requires='>=3.5.0', + tests_require=['pytest'], + classifiers=[ + 'Intended Audience :: Science/Research', + 'Development Status :: 1 - Alpha', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + ], +) diff --git a/tests/tokenization_test.py b/tests/tokenization_test.py index 7c12ecccfe..fda1cdb243 100644 --- a/tests/tokenization_test.py +++ b/tests/tokenization_test.py @@ -34,7 +34,7 @@ class TokenizationTest(unittest.TestCase): vocab_file = vocab_writer.name - tokenizer = tokenization.FullTokenizer(vocab_file) + tokenizer = tokenization.BertTokenizer(vocab_file) os.remove(vocab_file) tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")