Fix it to work with BART (#6756)
This commit is contained in:
@@ -187,7 +187,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
"end_positions": batch[4],
|
||||
}
|
||||
|
||||
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
|
||||
if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart"]:
|
||||
del inputs["token_type_ids"]
|
||||
|
||||
if args.model_type in ["xlnet", "xlm"]:
|
||||
@@ -300,7 +300,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
"token_type_ids": batch[2],
|
||||
}
|
||||
|
||||
if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
|
||||
if args.model_type in ["xlm", "roberta", "distilbert", "camembert", "bart"]:
|
||||
del inputs["token_type_ids"]
|
||||
|
||||
feature_indices = batch[3]
|
||||
|
||||
Reference in New Issue
Block a user