DistilBERT token type ids removed from inputs in run_squad
This commit is contained in:
@@ -207,11 +207,14 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
inputs = {
|
inputs = {
|
||||||
"input_ids": batch[0],
|
"input_ids": batch[0],
|
||||||
"attention_mask": batch[1],
|
"attention_mask": batch[1],
|
||||||
"token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
|
"token_type_ids": batch[2],
|
||||||
"start_positions": batch[3],
|
"start_positions": batch[3],
|
||||||
"end_positions": batch[4],
|
"end_positions": batch[4],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args.model_type in ["xlm", "roberta", "distilbert"]:
|
||||||
|
del inputs["token_type_ids"]
|
||||||
|
|
||||||
if args.model_type in ["xlnet", "xlm"]:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||||
if args.version_2_with_negative:
|
if args.version_2_with_negative:
|
||||||
@@ -316,8 +319,12 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
inputs = {
|
inputs = {
|
||||||
"input_ids": batch[0],
|
"input_ids": batch[0],
|
||||||
"attention_mask": batch[1],
|
"attention_mask": batch[1],
|
||||||
"token_type_ids": None if args.model_type in ["xlm", "roberta", "distilbert"] else batch[2],
|
"token_type_ids": batch[2],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args.model_type in ["xlm", "roberta", "distilbert"]:
|
||||||
|
del inputs["token_type_ids"]
|
||||||
|
|
||||||
example_indices = batch[3]
|
example_indices = batch[3]
|
||||||
|
|
||||||
# XLNet and XLM use more arguments for their predictions
|
# XLNet and XLM use more arguments for their predictions
|
||||||
@@ -427,10 +434,14 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Init features and dataset from cache if it exists
|
# Init features and dataset from cache if it exists
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
features_and_dataset = torch.load(cached_features_file)
|
features_and_dataset = torch.load(cached_features_file)
|
||||||
features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
|
features, dataset, examples = (
|
||||||
|
features_and_dataset["features"],
|
||||||
|
features_and_dataset["dataset"],
|
||||||
|
features_and_dataset["examples"],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", input_dir)
|
logger.info("Creating features from dataset file at %s", input_dir)
|
||||||
|
|
||||||
@@ -465,7 +476,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
|
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
torch.save({"features": features, "dataset": dataset}, cached_features_file)
|
torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
|
||||||
|
|
||||||
if args.local_rank == 0 and not evaluate:
|
if args.local_rank == 0 and not evaluate:
|
||||||
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
@@ -776,7 +787,7 @@ def main():
|
|||||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir, force_download=True)
|
model = model_class.from_pretrained(args.output_dir) # , force_download=True)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
@@ -801,7 +812,7 @@ def main():
|
|||||||
for checkpoint in checkpoints:
|
for checkpoint in checkpoints:
|
||||||
# Reload the model
|
# Reload the model
|
||||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||||
model = model_class.from_pretrained(checkpoint, force_download=True)
|
model = model_class.from_pretrained(checkpoint) # , force_download=True)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluate
|
# Evaluate
|
||||||
|
|||||||
Reference in New Issue
Block a user