Return dataset (pytorch)
This commit is contained in:
@@ -7,7 +7,11 @@ import numpy as np
|
|||||||
|
|
||||||
from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
|
from ...tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||||
from .utils import DataProcessor, InputExample, InputFeatures
|
from .utils import DataProcessor, InputExample, InputFeatures
|
||||||
from ...file_utils import is_tf_available
|
from ...file_utils import is_tf_available, is_torch_available
|
||||||
|
|
||||||
|
if is_torch_available:
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import TensorDataset
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -73,7 +77,8 @@ def _is_whitespace(c):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||||
doc_stride, max_query_length, is_training):
|
doc_stride, max_query_length, is_training,
|
||||||
|
return_dataset=False):
|
||||||
"""
|
"""
|
||||||
Converts a list of examples into a list of features that can be directly given as input to a model.
|
Converts a list of examples into a list of features that can be directly given as input to a model.
|
||||||
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
|
It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
|
||||||
@@ -84,7 +89,10 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
max_seq_length: The maximum sequence length of the inputs.
|
max_seq_length: The maximum sequence length of the inputs.
|
||||||
doc_stride: The stride used when the context is too large and is split across several features.
|
doc_stride: The stride used when the context is too large and is split across several features.
|
||||||
max_query_length: The maximum length of the query.
|
max_query_length: The maximum length of the query.
|
||||||
is_training: wheter to create features for model evaluation or model training.
|
is_training: whether to create features for model evaluation or model training.
|
||||||
|
return_dataset: Default False. Either 'pt' or 'tf'.
|
||||||
|
if 'pt': returns a torch.data.TensorDataset,
|
||||||
|
if 'tf': returns a tf.data.Dataset
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of :class:`~transformers.data.processors.squad.SquadFeatures`
|
list of :class:`~transformers.data.processors.squad.SquadFeatures`
|
||||||
@@ -264,6 +272,31 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
|
|
||||||
unique_id += 1
|
unique_id += 1
|
||||||
|
|
||||||
|
if return_dataset == 'pt':
|
||||||
|
if not is_torch_available():
|
||||||
|
raise ImportError("Pytorch must be installed to return a pytorch dataset.")
|
||||||
|
|
||||||
|
# Convert to Tensors and build dataset
|
||||||
|
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||||
|
all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
||||||
|
all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
||||||
|
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
||||||
|
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||||
|
|
||||||
|
if not is_training:
|
||||||
|
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||||
|
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||||
|
all_example_index, all_cls_index, all_p_mask)
|
||||||
|
else:
|
||||||
|
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||||
|
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||||
|
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||||
|
all_start_positions, all_end_positions,
|
||||||
|
all_cls_index, all_p_mask)
|
||||||
|
|
||||||
|
return features, dataset
|
||||||
|
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
@@ -359,7 +392,7 @@ class SquadProcessor(DataProcessor):
|
|||||||
if self.dev_file is None:
|
if self.dev_file is None:
|
||||||
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
||||||
|
|
||||||
with open(os.path.join(data_dir, self.dev_file if filename is not None else filename), "r", encoding='utf-8') as reader:
|
with open(os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding='utf-8') as reader:
|
||||||
input_data = json.load(reader)["data"]
|
input_data = json.load(reader)["data"]
|
||||||
return self._create_examples(input_data, "dev")
|
return self._create_examples(input_data, "dev")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user