logging + update copyright
This commit is contained in:
@@ -1,4 +1,17 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
"""Convert BERT checkpoint."""
|
"""Convert BERT checkpoint."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The Google AI Language Team Authors.
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -12,7 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Extract pre-computed feature vectors from BERT."""
|
"""Extract pre-computed feature vectors from a PyTorch BERT model."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The Google AI Language Team Authors.
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -12,7 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Common utility functions related to TensorFlow."""
|
"""PyTorch BERT model."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
|||||||
@@ -1,3 +1,19 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""PyTorch optimization for BERT model."""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The Google AI Language Team Authors.
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|||||||
33
run_squad.py
33
run_squad.py
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The Google AI Language Team Authors.
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -720,22 +720,6 @@ def main():
|
|||||||
help="The maximum length of an answer that can be generated. This is needed because the start "
|
help="The maximum length of an answer that can be generated. This is needed because the start "
|
||||||
"and end predictions are not conditioned on one another.")
|
"and end predictions are not conditioned on one another.")
|
||||||
|
|
||||||
### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
|
|
||||||
# parser.add_argument("--use_tpu", default=False, action='store_true', help="Whether to use TPU or GPU/CPU.")
|
|
||||||
# parser.add_argument("--tpu_name", default=None, type=str,
|
|
||||||
# help="The Cloud TPU to use for training. This should be either the name used when creating the "
|
|
||||||
# "Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.")
|
|
||||||
# parser.add_argument("--tpu_zone", default=None, type=str,
|
|
||||||
# help="[Optional] GCE zone where the Cloud TPU is located in. If not specified, we will attempt "
|
|
||||||
# "to automatically detect the GCE project from metadata.")
|
|
||||||
# parser.add_argument("--gcp_project", default=None, type=str,
|
|
||||||
# help="[Optional] Project name for the Cloud TPU-enabled project. If not specified, we will attempt "
|
|
||||||
# "to automatically detect the GCE project from metadata.")
|
|
||||||
# parser.add_argument("--master", default=None, type=str, help="[Optional] TensorFlow master URL.")
|
|
||||||
# parser.add_argument("--num_tpu_cores", default=8, type=int, help="Only used if `use_tpu` is True. "
|
|
||||||
# "Total number of TPU cores to use.")
|
|
||||||
### END - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
|
|
||||||
|
|
||||||
parser.add_argument("--verbose_logging", default=False, action='store_true',
|
parser.add_argument("--verbose_logging", default=False, action='store_true',
|
||||||
help="If true, all of the warnings related to data processing will be printed. "
|
help="If true, all of the warnings related to data processing will be printed. "
|
||||||
"A number of warnings are expected for a normal SQuAD evaluation.")
|
"A number of warnings are expected for a normal SQuAD evaluation.")
|
||||||
@@ -859,10 +843,10 @@ def main():
|
|||||||
segment_ids = segment_ids.to(device)
|
segment_ids = segment_ids.to(device)
|
||||||
start_positions = start_positions.to(device)
|
start_positions = start_positions.to(device)
|
||||||
end_positions = start_positions.to(device)
|
end_positions = start_positions.to(device)
|
||||||
|
|
||||||
start_positions = start_positions.view(-1, 1)
|
start_positions = start_positions.view(-1, 1)
|
||||||
end_positions = end_positions.view(-1, 1)
|
end_positions = end_positions.view(-1, 1)
|
||||||
|
|
||||||
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
|
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
|
||||||
if n_gpu > 1:
|
if n_gpu > 1:
|
||||||
loss = loss.mean() # mean() to average on multi-gpu.
|
loss = loss.mean() # mean() to average on multi-gpu.
|
||||||
@@ -871,7 +855,6 @@ def main():
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
logger.info("Done %s steps", global_step)
|
|
||||||
|
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
eval_examples = read_squad_examples(
|
eval_examples = read_squad_examples(
|
||||||
@@ -892,10 +875,8 @@ def main():
|
|||||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||||
#all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
|
||||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||||
|
|
||||||
#eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_example_index)
|
|
||||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
|
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
|
||||||
if args.local_rank == -1:
|
if args.local_rank == -1:
|
||||||
eval_sampler = SequentialSampler(eval_data)
|
eval_sampler = SequentialSampler(eval_data)
|
||||||
@@ -906,7 +887,6 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
all_results = []
|
all_results = []
|
||||||
logger.info("Start evaluating")
|
logger.info("Start evaluating")
|
||||||
#for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
|
|
||||||
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"):
|
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"):
|
||||||
if len(all_results) % 1000 == 0:
|
if len(all_results) % 1000 == 0:
|
||||||
logger.info("Processing example: %d" % (len(all_results)))
|
logger.info("Processing example: %d" % (len(all_results)))
|
||||||
@@ -918,9 +898,7 @@ def main():
|
|||||||
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
|
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
|
||||||
|
|
||||||
unique_id = [int(eval_features[e.item()].unique_id) for e in example_index]
|
unique_id = [int(eval_features[e.item()].unique_id) for e in example_index]
|
||||||
#start_logits = [x.item() for x in start_logits]
|
|
||||||
start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
|
start_logits = [x.view(-1).detach().cpu().numpy() for x in start_logits]
|
||||||
#end_logits = [x.item() for x in end_logits]
|
|
||||||
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
|
end_logits = [x.view(-1).detach().cpu().numpy() for x in end_logits]
|
||||||
for idx, i in enumerate(unique_id):
|
for idx, i in enumerate(unique_id):
|
||||||
s = [float(x) for x in start_logits[idx]]
|
s = [float(x) for x in start_logits[idx]]
|
||||||
@@ -932,11 +910,6 @@ def main():
|
|||||||
end_logits=e
|
end_logits=e
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# all_results.append(
|
|
||||||
# RawResult(
|
|
||||||
# unique_id=unique_id,
|
|
||||||
# start_logits=start_logits,
|
|
||||||
# end_logits=end_logits))
|
|
||||||
|
|
||||||
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
|
output_prediction_file = os.path.join(args.output_dir, "predictions.json")
|
||||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
|
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 The Google AI Language Team Authors.
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|||||||
Reference in New Issue
Block a user