Clean up SQuAD and allow train_file and predict_file usage
This commit is contained in:
@@ -337,7 +337,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", input_dir)
|
logger.info("Creating features from dataset file at %s", input_dir)
|
||||||
|
|
||||||
if not args.data_dir:
|
if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
|
||||||
try:
|
try:
|
||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -350,7 +350,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
|
examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
|
||||||
else:
|
else:
|
||||||
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
||||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
|
||||||
|
if evaluate:
|
||||||
|
examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
|
||||||
|
else:
|
||||||
|
examples = processor.get_train_examples(args.data_dir, filename=args.train_file)
|
||||||
|
|
||||||
features, dataset = squad_convert_examples_to_features(
|
features, dataset = squad_convert_examples_to_features(
|
||||||
examples=examples,
|
examples=examples,
|
||||||
@@ -387,7 +391,14 @@ def main():
|
|||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str,
|
parser.add_argument("--data_dir", default=None, type=str,
|
||||||
help="The input data dir. Should contain the .json files for the task. If not specified, will run with tensorflow_datasets.")
|
help="The input data dir. Should contain the .json files for the task." +
|
||||||
|
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
||||||
|
parser.add_argument("--train_file", default=None, type=str,
|
||||||
|
help="The input training file. If a data dir is specified, will look for the file there" +
|
||||||
|
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
||||||
|
parser.add_argument("--predict_file", default=None, type=str,
|
||||||
|
help="The input evaluation file. If a data dir is specified, will look for the file there" +
|
||||||
|
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument("--config_name", default="", type=str,
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
help="Pretrained config name or path if not the same as model_name")
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||||
@@ -472,11 +483,6 @@ def main():
|
|||||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
args.predict_file = os.path.join(args.output_dir, 'predictions_{}_{}.txt'.format(
|
|
||||||
list(filter(None, args.model_name_or_path.split('/'))).pop(),
|
|
||||||
str(args.max_seq_length))
|
|
||||||
)
|
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||||
|
|
||||||
|
|||||||
@@ -373,6 +373,9 @@ class SquadProcessor(DataProcessor):
|
|||||||
which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
|
which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if data_dir is None:
|
||||||
|
data_dir = ""
|
||||||
|
|
||||||
if self.train_file is None:
|
if self.train_file is None:
|
||||||
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
||||||
|
|
||||||
@@ -389,6 +392,9 @@ class SquadProcessor(DataProcessor):
|
|||||||
filename: None by default, specify this if the evaluation file has a different name than the original one
|
filename: None by default, specify this if the evaluation file has a different name than the original one
|
||||||
which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
|
which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
|
||||||
"""
|
"""
|
||||||
|
if data_dir is None:
|
||||||
|
data_dir = ""
|
||||||
|
|
||||||
if self.dev_file is None:
|
if self.dev_file is None:
|
||||||
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user