Merge branch 'master' into run_multiple_choice_merge

# Please enter a commit message to explain why this merge is necessary,
# especially if it merges an updated upstream into a topic branch.
#
# Lines starting with '#' will be ignored, and an empty message aborts
# the commit.
This commit is contained in:
erenup
2019-09-18 21:43:46 +08:00
5 changed files with 198 additions and 39 deletions

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for multiple choice (Bert, XLM, XLNet)."""
""" Finetuning the library models for multiple choice (Bert, Roberta, XLNet)."""
from __future__ import absolute_import, division, print_function
@@ -44,7 +44,7 @@ from utils_multiple_choice import (convert_examples_to_features, processors)
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig)), ())
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ())
MODEL_CLASSES = {
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
@@ -209,7 +209,6 @@ def train(args, train_dataset, model, tokenizer):
def evaluate(args, model, tokenizer, prefix="", test=False):
# Loop to handle MNLI double evaluation (matched, mis-matched)
eval_task_names = (args.task_name,)
eval_outputs_dirs = (args.output_dir,)
@@ -260,7 +259,7 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
result = {"eval_acc": acc, "eval_loss": eval_loss}
results.update(result)
output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test) + "_eval_results.txt")
output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test).lower() + "_eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
@@ -523,9 +522,9 @@ def main():
if not args.do_train:
args.output_dir = args.model_name_or_path
checkpoints = [args.output_dir]
if args.eval_all_checkpoints: #can not use this to do test!! just for different paras
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
# if args.eval_all_checkpoints: # can not use this to do test!!
# checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
# logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""