Fix it to work with BART (#6756)

This commit is contained in:
Tom Grek
2020-08-27 06:04:50 -07:00
committed by GitHub
parent 0d2c111a0c
commit c225e872ed

View File

@@ -187,7 +187,7 @@ def train(args, train_dataset, model, tokenizer):
"end_positions": batch[4], "end_positions": batch[4],
} }
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]: if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart"]:
del inputs["token_type_ids"] del inputs["token_type_ids"]
if args.model_type in ["xlnet", "xlm"]: if args.model_type in ["xlnet", "xlm"]:
@@ -300,7 +300,7 @@ def evaluate(args, model, tokenizer, prefix=""):
"token_type_ids": batch[2], "token_type_ids": batch[2],
} }
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]: if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart"]:
del inputs["token_type_ids"] del inputs["token_type_ids"]
feature_indices = batch[3] feature_indices = batch[3]