diff --git a/Comparing TF and PT models.ipynb b/Comparing TF and PT models.ipynb index 1148bac16e..e042bfc290 100644 --- a/Comparing TF and PT models.ipynb +++ b/Comparing TF and PT models.ipynb @@ -12,8 +12,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2018-11-02T13:05:56.692585Z", - "start_time": "2018-11-02T13:05:55.699169Z" + "end_time": "2018-11-02T14:09:09.239405Z", + "start_time": "2018-11-02T14:09:08.126668Z" } }, "outputs": [], @@ -23,11 +23,11 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2018-11-02T13:18:23.944585Z", - "start_time": "2018-11-02T13:18:23.821309Z" + "end_time": "2018-11-02T14:09:09.370511Z", + "start_time": "2018-11-02T14:09:09.242527Z" } }, "outputs": [ @@ -67,11 +67,11 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2018-11-02T13:18:24.802620Z", - "start_time": "2018-11-02T13:18:24.764474Z" + "end_time": "2018-11-02T14:09:12.514617Z", + "start_time": "2018-11-02T14:09:09.372137Z" } }, "outputs": [ @@ -79,15 +79,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "WARNING:tensorflow:Estimator's model_fn (.model_fn at 0x128feb7b8>) includes params argument, but params are not passed to Estimator.\n", - "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpp9hntmfs\n", - "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpp9hntmfs', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", + "WARNING:tensorflow:Estimator's model_fn (.model_fn at 0x12b266ae8>) includes params argument, but params are not passed to Estimator.\n", + "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphrjfnoqh\n", + "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphrjfnoqh', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", - ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n", + ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n", "WARNING:tensorflow:Setting TPUConfig.num_shards==1 is an unsupported behavior. Please fix as soon as possible (leaving num_shards as None.\n", "INFO:tensorflow:_TPUContext: eval_on_tpu True\n", "WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.\n" @@ -123,11 +123,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2018-11-02T13:19:20.060587Z", - "start_time": "2018-11-02T13:19:14.804525Z" + "end_time": "2018-11-02T14:09:17.745970Z", + "start_time": "2018-11-02T14:09:12.516953Z" } }, "outputs": [ @@ -135,7 +135,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpp9hntmfs, running initialization to predict.\n", + "INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphrjfnoqh, running initialization to predict.\n", "INFO:tensorflow:Calling model_fn.\n", "INFO:tensorflow:Running infer on CPU\n", "INFO:tensorflow:Done calling model_fn.\n", @@ -154,7 +154,7 @@ " feature = unique_id_to_feature[unique_id]\n", " output_json = collections.OrderedDict()\n", " output_json[\"linex_index\"] = unique_id\n", - " all_features = []\n", + " all_out_features = []\n", " for (i, token) in enumerate(feature.tokens):\n", " all_layers = []\n", " for (j, layer_index) in enumerate(layer_indexes):\n", @@ -165,54 +165,75 @@ " round(float(x), 6) for x in layer_output[i:(i + 1)].flat\n", " ]\n", " all_layers.append(layers)\n", - " features = collections.OrderedDict()\n", - " features[\"token\"] = token\n", - " features[\"layers\"] = all_layers\n", - " all_features.append(features)\n", - " output_json[\"features\"] = all_features\n", + " out_features = collections.OrderedDict()\n", + " out_features[\"token\"] = token\n", + " out_features[\"layers\"] = all_layers\n", + " all_out_features.append(out_features)\n", + " output_json[\"features\"] = all_out_features\n", " all_out.append(output_json)" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2018-11-02T13:22:39.694206Z", - "start_time": "2018-11-02T13:22:39.663432Z" + "end_time": "2018-11-02T14:09:17.780532Z", + "start_time": "2018-11-02T14:09:17.748778Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n", + "2\n", + "odict_keys(['linex_index', 'features'])\n", + "14\n" + ] + } + ], + "source": [ + "print(len(all_out))\n", + "print(len(all_out[0]))\n", + "print(all_out[0].keys())\n", + "print(len(all_out[0]['features']))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-02T14:09:17.818968Z", + "start_time": "2018-11-02T14:09:17.782121Z" } }, "outputs": [ { "data": { "text/plain": [ - "14" + "[-0.628111,\n", + " 0.193215,\n", + " -0.75185,\n", + " -0.040464,\n", + " -0.875331,\n", + " 0.15654,\n", + " 1.385444,\n", + " 1.066997,\n", + " -0.349549,\n", + " 0.270686]" ] }, - "execution_count": 32, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "len(all_out)\n", - "len(all_out[0])\n", - "all_out[0].keys()\n", - "len(all_out[0]['features'])" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": { - "ExecuteTime": { - "end_time": "2018-11-02T13:23:05.752981Z", - "start_time": "2018-11-02T13:23:05.723891Z" - } - }, - "outputs": [], - "source": [ - "tensorflow_output = all_out[0]['features'][0]['layers'][0]['values']" + "tensorflow_output = all_out[0]['features'][0]['layers'][0]['values']\n", + "tensorflow_output[:10]" ] }, { @@ -224,11 +245,11 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2018-11-02T13:24:27.644785Z", - "start_time": "2018-11-02T13:24:27.611996Z" + "end_time": "2018-11-02T14:09:17.954196Z", + "start_time": "2018-11-02T14:09:17.821115Z" } }, "outputs": [], @@ -236,6 +257,794 @@ "from extract_features_pytorch import *" ] }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-02T14:09:19.196475Z", + "start_time": "2018-11-02T14:09:17.956199Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "BertModel(\n", + " (embeddings): BERTEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (encoder): BERTEncoder(\n", + " (layer): ModuleList(\n", + " (0): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (1): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (2): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (3): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (4): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (5): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (6): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (7): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (8): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (9): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (10): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (11): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BERTPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + ")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "init_checkpoint_pt=\"/Users/thomaswolf/Documents/Thomas/Code/HF/BERT/google_models/uncased_L-12_H-768_A-12/pytorch_model.bin\"\n", + "\n", + "device = torch.device(\"cpu\")\n", + "model = BertModel(bert_config)\n", + "model.load_state_dict(torch.load(init_checkpoint_pt, map_location='cpu'))\n", + "model.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-02T14:09:19.236256Z", + "start_time": "2018-11-02T14:09:19.198407Z" + }, + "code_folding": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "BertModel(\n", + " (embeddings): BERTEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (encoder): BERTEncoder(\n", + " (layer): ModuleList(\n", + " (0): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (1): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (2): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (3): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (4): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (5): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (6): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (7): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (8): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (9): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (10): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (11): BERTLayer(\n", + " (attention): BERTAttention(\n", + " (self): BERTSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " (output): BERTSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " (intermediate): BERTIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BERTOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): BERTLayerNorm()\n", + " (dropout): Dropout(p=0.1)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BERTPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + ")" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)\n", + "all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)\n", + "all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)\n", + "\n", + "eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)\n", + "eval_sampler = SequentialSampler(eval_data)\n", + "eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)\n", + "\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-02T14:09:19.671994Z", + "start_time": "2018-11-02T14:09:19.239454Z" + } + }, + "outputs": [], + "source": [ + "pytorch_all_out = []\n", + "for input_ids, input_mask, example_indices in eval_dataloader:\n", + " input_ids = input_ids.to(device)\n", + " input_mask = input_mask.float().to(device)\n", + "\n", + " all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)\n", + "\n", + " for enc_layers, example_index in zip(all_encoder_layers, example_indices):\n", + " feature = features[example_index.item()]\n", + " unique_id = int(feature.unique_id)\n", + " # feature = unique_id_to_feature[unique_id]\n", + " output_json = collections.OrderedDict()\n", + " output_json[\"linex_index\"] = unique_id\n", + " all_out_features = []\n", + " for (i, token) in enumerate(feature.tokens):\n", + " all_layers = []\n", + " for (j, layer_index) in enumerate(layer_indexes):\n", + " layer_output = enc_layers[int(layer_index)].detach().cpu().numpy()\n", + " layers = collections.OrderedDict()\n", + " layers[\"index\"] = layer_index\n", + " layers[\"values\"] = [\n", + " round(float(x), 6) for x in layer_output[i:(i + 1)].flat\n", + " ]\n", + " all_layers.append(layers)\n", + " out_features = collections.OrderedDict()\n", + " out_features[\"token\"] = token\n", + " out_features[\"layers\"] = all_layers\n", + " all_out_features.append(out_features)\n", + " output_json[\"features\"] = all_out_features\n", + " pytorch_all_out.append(output_json)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-02T14:09:19.706616Z", + "start_time": "2018-11-02T14:09:19.673670Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n", + "2\n", + "odict_keys(['linex_index', 'features'])\n", + "14\n" + ] + } + ], + "source": [ + "print(len(pytorch_all_out))\n", + "print(len(pytorch_all_out[0]))\n", + "print(pytorch_all_out[0].keys())\n", + "print(len(pytorch_all_out[0]['features']))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-02T14:10:28.295669Z", + "start_time": "2018-11-02T14:10:28.263140Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[-0.016153,\n", + " -0.697252,\n", + " -0.298296,\n", + " -0.167194,\n", + " -0.219306,\n", + " 0.061712,\n", + " -0.006953,\n", + " 0.366519,\n", + " -0.031027,\n", + " -0.33547]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pytorch_output = pytorch_all_out[0]['features'][0]['layers'][0]['values']\n", + "pytorch_output[:10]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2018-11-02T14:10:34.540457Z", + "start_time": "2018-11-02T14:10:34.510109Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[-0.628111,\n", + " 0.193215,\n", + " -0.75185,\n", + " -0.040464,\n", + " -0.875331,\n", + " 0.15654,\n", + " 1.385444,\n", + " 1.066997,\n", + " -0.349549,\n", + " 0.270686]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tensorflow_output[:10]" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/extract_features_pytorch.py b/extract_features_pytorch.py index 6f4f79e0f9..7596298cca 100644 --- a/extract_features_pytorch.py +++ b/extract_features_pytorch.py @@ -26,6 +26,7 @@ import json import re import tokenization +import torch from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -251,10 +252,9 @@ def main(): all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) - all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) - eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) + eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) if args.local_rank == -1: eval_sampler = SequentialSampler(eval_data) else: @@ -263,12 +263,11 @@ def main(): model.eval() with open(args.output_file, "w", encoding='utf-8') as writer: - for input_ids, input_mask, segment_ids, example_indices in eval_dataloader: + for input_ids, input_mask, example_indices in eval_dataloader: input_ids = input_ids.to(device) input_mask = input_mask.float().to(device) - segment_ids = segment_ids.to(device) - all_encoder_layers, _ = model(input_ids, segment_ids, input_mask) + all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) for enc_layers, example_index in zip(all_encoder_layers, example_indices): feature = features[example_index.item()] diff --git a/modeling_pytorch.py b/modeling_pytorch.py index ca6e49f24d..4a8514e3a0 100644 --- a/modeling_pytorch.py +++ b/modeling_pytorch.py @@ -377,12 +377,17 @@ class BertModel(nn.Module): self.encoder = BERTEncoder(config) self.pooler = BERTPooler(config) - def forward(self, input_ids, token_type_ids, attention_mask): + def forward(self, input_ids, token_type_ids=None, attention_mask=None): # We create 3D attention mask from a 2D tensor mask. # Sizes are [batch_size, 1, 1, from_seq_length] # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length] # It's more simple than the triangular masking of causal attention, just need to # prepare the broadcast here + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) attention_mask = (1.0 - attention_mask) * -10000.0