Style
This commit is contained in:
@@ -76,9 +76,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data."
|
"--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model")
|
||||||
"--do_predict", action="store_true", help="Eval the question answering model"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||||
)
|
)
|
||||||
@@ -396,7 +394,6 @@ def main():
|
|||||||
|
|
||||||
return tokenized_examples
|
return tokenized_examples
|
||||||
|
|
||||||
|
|
||||||
if "train" not in raw_datasets:
|
if "train" not in raw_datasets:
|
||||||
raise ValueError("--do_train requires a train dataset")
|
raise ValueError("--do_train requires a train dataset")
|
||||||
train_dataset = raw_datasets["train"]
|
train_dataset = raw_datasets["train"]
|
||||||
@@ -481,7 +478,6 @@ def main():
|
|||||||
|
|
||||||
return tokenized_examples
|
return tokenized_examples
|
||||||
|
|
||||||
|
|
||||||
if "validation" not in raw_datasets:
|
if "validation" not in raw_datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_examples = raw_datasets["validation"]
|
eval_examples = raw_datasets["validation"]
|
||||||
@@ -539,11 +535,8 @@ def main():
|
|||||||
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
|
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||||
eval_dataloader = DataLoader(
|
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
||||||
eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||||
@@ -605,7 +598,7 @@ def main():
|
|||||||
if step + batch_size < len(dataset):
|
if step + batch_size < len(dataset):
|
||||||
logits_concat[step : step + batch_size, :cols] = output_logit
|
logits_concat[step : step + batch_size, :cols] = output_logit
|
||||||
else:
|
else:
|
||||||
logits_concat[step:, :cols] = output_logit[:len(dataset) - step]
|
logits_concat[step:, :cols] = output_logit[: len(dataset) - step]
|
||||||
|
|
||||||
step += batch_size
|
step += batch_size
|
||||||
|
|
||||||
|
|||||||
@@ -81,9 +81,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data."
|
"--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model")
|
||||||
"--do_predict", action="store_true", help="Eval the question answering model"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||||
)
|
)
|
||||||
@@ -543,9 +541,7 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||||
eval_dataloader = DataLoader(
|
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
||||||
eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||||
@@ -607,7 +603,7 @@ def main():
|
|||||||
if step + batch_size < len(dataset):
|
if step + batch_size < len(dataset):
|
||||||
logits_concat[step : step + batch_size, :cols] = output_logit
|
logits_concat[step : step + batch_size, :cols] = output_logit
|
||||||
else:
|
else:
|
||||||
logits_concat[step:, :cols] = output_logit[:len(dataset) - step]
|
logits_concat[step:, :cols] = output_logit[: len(dataset) - step]
|
||||||
|
|
||||||
step += batch_size
|
step += batch_size
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user