@@ -22,15 +22,7 @@ from typing import List, Optional, Union
|
||||
import tqdm
|
||||
from filelock import FileLock
|
||||
|
||||
from transformers import (
|
||||
DataProcessor,
|
||||
PreTrainedTokenizer,
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
XLMRobertaTokenizer,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
from transformers import DataProcessor, PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -106,7 +98,6 @@ if is_torch_available():
|
||||
evaluate: bool = False,
|
||||
):
|
||||
processor = hans_processors[task]()
|
||||
output_mode = hans_output_modes[task]
|
||||
|
||||
cached_features_file = os.path.join(
|
||||
data_dir,
|
||||
@@ -127,22 +118,12 @@ if is_torch_available():
|
||||
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||
label_list = processor.get_labels()
|
||||
|
||||
if task in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
XLMRobertaTokenizer,
|
||||
):
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
examples = (
|
||||
processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
|
||||
)
|
||||
|
||||
logger.info("Training examples: %s", len(examples))
|
||||
# TODO clean up all this to leverage built-in features of tokenizers
|
||||
self.features = hans_convert_examples_to_features(
|
||||
examples, label_list, max_seq_length, tokenizer, output_mode
|
||||
)
|
||||
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(self.features, cached_features_file)
|
||||
|
||||
@@ -174,21 +155,10 @@ if is_tf_available():
|
||||
evaluate: bool = False,
|
||||
):
|
||||
processor = hans_processors[task]()
|
||||
output_mode = hans_output_modes[task]
|
||||
label_list = processor.get_labels()
|
||||
|
||||
if task in ["mnli", "mnli-mm"] and tokenizer.__class__ in (
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
XLMRobertaTokenizer,
|
||||
):
|
||||
# HACK(label indices are swapped in RoBERTa pretrained model)
|
||||
label_list[1], label_list[2] = label_list[2], label_list[1]
|
||||
|
||||
examples = processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir)
|
||||
self.features = hans_convert_examples_to_features(
|
||||
examples, label_list, max_seq_length, tokenizer, output_mode
|
||||
)
|
||||
self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer)
|
||||
|
||||
def gen():
|
||||
for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
|
||||
@@ -240,15 +210,6 @@ if is_tf_available():
|
||||
class HansProcessor(DataProcessor):
|
||||
"""Processor for the HANS data set."""
|
||||
|
||||
def get_example_from_tensor_dict(self, tensor_dict):
    """Build an ``InputExample`` from a dict of tensor-like values.

    Reads the ``"idx"``, ``"premise"``, ``"hypothesis"`` and ``"label"``
    entries of ``tensor_dict`` and calls ``.numpy()`` on each one. The
    premise and hypothesis results are decoded as UTF-8, and the label
    result is converted with ``str``. The four values are passed to
    ``InputExample`` positionally, in that order.
    """
    guid = tensor_dict["idx"].numpy()
    premise = tensor_dict["premise"].numpy().decode("utf-8")
    hypothesis = tensor_dict["hypothesis"].numpy().decode("utf-8")
    label = str(tensor_dict["label"].numpy())
    return InputExample(guid, premise, hypothesis, label)
|
||||
|
||||
def get_train_examples(self, data_dir):
    """Return the examples read from ``heuristics_train_set.txt`` in ``data_dir``.

    The file is read with ``self._read_tsv`` and the resulting rows are
    handed to ``self._create_examples`` with the set type ``"train"``.
    """
    train_path = os.path.join(data_dir, "heuristics_train_set.txt")
    rows = self._read_tsv(train_path)
    return self._create_examples(rows, "train")
|
||||
@@ -277,11 +238,7 @@ class HansProcessor(DataProcessor):
|
||||
|
||||
|
||||
def hans_convert_examples_to_features(
|
||||
examples: List[InputExample],
|
||||
label_list: List[str],
|
||||
max_length: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
output_mode: str,
|
||||
examples: List[InputExample], label_list: List[str], max_length: int, tokenizer: PreTrainedTokenizer,
|
||||
):
|
||||
"""
|
||||
Loads a data file into a list of ``InputFeatures``
|
||||
@@ -313,19 +270,8 @@ def hans_convert_examples_to_features(
|
||||
pad_to_max_length=True,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
|
||||
logger.info(
|
||||
"Attention! you are cropping tokens (swag task is ok). "
|
||||
"If you are training ARC and RACE and you are poping question + options,"
|
||||
"you need to try to use a bigger max seq length!"
|
||||
)
|
||||
|
||||
if output_mode == "classification":
|
||||
label = label_map[example.label] if example.label in label_map else 0
|
||||
elif output_mode == "regression":
|
||||
label = float(example.label)
|
||||
else:
|
||||
raise KeyError(output_mode)
|
||||
label = label_map[example.label] if example.label in label_map else 0
|
||||
|
||||
pairID = int(example.pairID)
|
||||
|
||||
@@ -346,7 +292,3 @@ hans_tasks_num_labels = {
|
||||
# Maps a task name to the processor class used to load that task's data.
hans_processors = {"hans": HansProcessor}
|
||||
|
||||
# Maps a task name to its output mode string.
hans_output_modes = {"hans": "classification"}
|
||||
|
||||
Reference in New Issue
Block a user