Merge branch 'master' into run_multiple_choice_merge
# Please enter a commit message to explain why this merge is necessary, # especially if it merges an updated upstream into a topic branch. # # Lines starting with '#' will be ignored, and an empty message aborts # the commit.
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning the library models for multiple choice (Bert, XLM, XLNet)."""
|
||||
""" Finetuning the library models for multiple choice (Bert, Roberta, XLNet)."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
@@ -44,7 +44,7 @@ from utils_multiple_choice import (convert_examples_to_features, processors)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig)), ())
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
|
||||
@@ -209,7 +209,6 @@ def train(args, train_dataset, model, tokenizer):
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix="", test=False):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
eval_task_names = (args.task_name,)
|
||||
eval_outputs_dirs = (args.output_dir,)
|
||||
|
||||
@@ -260,7 +259,7 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
|
||||
result = {"eval_acc": acc, "eval_loss": eval_loss}
|
||||
results.update(result)
|
||||
|
||||
output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test) + "_eval_results.txt")
|
||||
output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test).lower() + "_eval_results.txt")
|
||||
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
|
||||
@@ -523,9 +522,9 @@ def main():
|
||||
if not args.do_train:
|
||||
args.output_dir = args.model_name_or_path
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints: #can not use this to do test!! just for different paras
|
||||
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
# if args.eval_all_checkpoints: # can not use this to do test!!
|
||||
# checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
|
||||
# logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
||||
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
||||
for checkpoint in checkpoints:
|
||||
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
|
||||
|
||||
Reference in New Issue
Block a user