Patch evaluation for impossible values + cleanup
This commit is contained in:
@@ -55,7 +55,7 @@ Example usage
|
|||||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
An example using these processors is given in the
|
An example using these processors is given in the
|
||||||
`run_glue.py <https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_glue.py>`__ script.
|
`run_glue.py <https://github.com/huggingface/transformers/blob/master/examples/run_glue.py>`__ script.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -132,4 +132,4 @@ Example::
|
|||||||
|
|
||||||
|
|
||||||
Another example using these processors is given in the
|
Another example using these processors is given in the
|
||||||
`run_squad.py <https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py>`__ script.
|
`run_squad.py <https://github.com/huggingface/transformers/blob/master/examples/run_squad.py>`__ script.
|
||||||
@@ -311,7 +311,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
str(args.max_seq_length)))
|
str(args.max_seq_length)))
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
features = torch.load(cached_features_file)
|
features_and_dataset = torch.load(cached_features_file)
|
||||||
|
features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
|
||||||
else:
|
else:
|
||||||
logger.info("Creating features from dataset file at %s", input_dir)
|
logger.info("Creating features from dataset file at %s", input_dir)
|
||||||
|
|
||||||
@@ -330,40 +331,24 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
||||||
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
||||||
|
|
||||||
features = squad_convert_examples_to_features(
|
features, dataset = squad_convert_examples_to_features(
|
||||||
examples=examples,
|
examples=examples,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
doc_stride=args.doc_stride,
|
doc_stride=args.doc_stride,
|
||||||
max_query_length=args.max_query_length,
|
max_query_length=args.max_query_length,
|
||||||
is_training=not evaluate,
|
is_training=not evaluate,
|
||||||
|
return_dataset='pt'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
torch.save(features, cached_features_file)
|
torch.save({"features": features, "dataset": dataset}, cached_features_file)
|
||||||
|
|
||||||
if args.local_rank == 0 and not evaluate:
|
if args.local_rank == 0 and not evaluate:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
|
|
||||||
# Convert to Tensors and build dataset
|
|
||||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
|
||||||
all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
|
||||||
all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
|
||||||
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
|
||||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
|
||||||
if evaluate:
|
|
||||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
|
||||||
all_example_index, all_cls_index, all_p_mask)
|
|
||||||
else:
|
|
||||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
|
||||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
|
||||||
all_start_positions, all_end_positions,
|
|
||||||
all_cls_index, all_p_mask)
|
|
||||||
|
|
||||||
if output_examples:
|
if output_examples:
|
||||||
return dataset, examples, features
|
return dataset, examples, features
|
||||||
return dataset
|
return dataset
|
||||||
|
|||||||
@@ -312,7 +312,7 @@ class SquadProcessor(DataProcessor):
|
|||||||
if not evaluate:
|
if not evaluate:
|
||||||
answer = tensor_dict['answers']['text'][0].numpy().decode('utf-8')
|
answer = tensor_dict['answers']['text'][0].numpy().decode('utf-8')
|
||||||
answer_start = tensor_dict['answers']['answer_start'][0].numpy()
|
answer_start = tensor_dict['answers']['answer_start'][0].numpy()
|
||||||
answers = None
|
answers = []
|
||||||
else:
|
else:
|
||||||
answers = [{
|
answers = [{
|
||||||
"answer_start": start.numpy(),
|
"answer_start": start.numpy(),
|
||||||
@@ -408,7 +408,7 @@ class SquadProcessor(DataProcessor):
|
|||||||
question_text = qa["question"]
|
question_text = qa["question"]
|
||||||
start_position_character = None
|
start_position_character = None
|
||||||
answer_text = None
|
answer_text = None
|
||||||
answers = None
|
answers = []
|
||||||
|
|
||||||
if "is_impossible" in qa:
|
if "is_impossible" in qa:
|
||||||
is_impossible = qa["is_impossible"]
|
is_impossible = qa["is_impossible"]
|
||||||
@@ -469,7 +469,7 @@ class SquadExample(object):
|
|||||||
answer_text,
|
answer_text,
|
||||||
start_position_character,
|
start_position_character,
|
||||||
title,
|
title,
|
||||||
answers=None,
|
answers=[],
|
||||||
is_impossible=False):
|
is_impossible=False):
|
||||||
self.qas_id = qas_id
|
self.qas_id = qas_id
|
||||||
self.question_text = question_text
|
self.question_text = question_text
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ class PreTrainedTokenizer(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def pad_token_type_id(self):
|
def pad_token_type_id(self):
|
||||||
""" Id of the padding token in the vocabulary. Log an error if used while not having been set. """
|
""" Id of the padding token type in the vocabulary."""
|
||||||
return self._pad_token_type_id
|
return self._pad_token_type_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
Reference in New Issue
Block a user