diff --git a/examples/adversarial/run_hans.py b/examples/adversarial/run_hans.py index fe384d0727..1bb6a12d15 100644 --- a/examples/adversarial/run_hans.py +++ b/examples/adversarial/run_hans.py @@ -33,7 +33,7 @@ from transformers import ( default_data_collator, set_seed, ) -from utils_hans import HansDataset, InputFeatures, hans_processors +from utils_hans import HansDataset, InputFeatures, hans_processors, hans_tasks_num_labels logger = logging.getLogger(__name__) @@ -130,9 +130,7 @@ def main(): set_seed(training_args.seed) try: - processor = hans_processors[data_args.task_name]() - label_list = processor.get_labels() - num_labels = len(label_list) + num_labels = hans_tasks_num_labels[data_args.task_name] except KeyError: raise ValueError("Task not found: %s" % (data_args.task_name)) @@ -214,6 +212,7 @@ def main(): pair_ids = [ex.pairID for ex in eval_dataset] output_eval_file = os.path.join(training_args.output_dir, "hans_predictions.txt") + label_list = eval_dataset.get_labels() if trainer.is_world_master(): with open(output_eval_file, "w") as writer: writer.write("pairID,gold_label\n") diff --git a/examples/adversarial/utils_hans.py b/examples/adversarial/utils_hans.py index 0dbc28149f..5058e8b45f 100644 --- a/examples/adversarial/utils_hans.py +++ b/examples/adversarial/utils_hans.py @@ -22,7 +22,17 @@ from typing import List, Optional, Union import tqdm from filelock import FileLock -from transformers import DataProcessor, PreTrainedTokenizer, is_tf_available, is_torch_available +from transformers import ( + BartTokenizer, + BartTokenizerFast, + DataProcessor, + PreTrainedTokenizer, + RobertaTokenizer, + RobertaTokenizerFast, + XLMRobertaTokenizer, + is_tf_available, + is_torch_available, +) logger = logging.getLogger(__name__) @@ -105,6 +115,17 @@ if is_torch_available(): "dev" if evaluate else "train", tokenizer.__class__.__name__, str(max_seq_length), task, ), ) + label_list = processor.get_labels() + if tokenizer.__class__ in ( + RobertaTokenizer, + RobertaTokenizerFast, + XLMRobertaTokenizer, + BartTokenizer, + BartTokenizerFast, + ): + # HACK(label indices are swapped in RoBERTa pretrained model) + label_list[1], label_list[2] = label_list[2], label_list[1] + self.label_list = label_list # Make sure only the first process in distributed training processes the dataset, # and the others will use the cache. @@ -116,7 +137,6 @@ if is_torch_available(): self.features = torch.load(cached_features_file) else: logger.info(f"Creating features from dataset file at {data_dir}") - label_list = processor.get_labels() examples = ( processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir) @@ -133,6 +153,9 @@ if is_torch_available(): def __getitem__(self, i) -> InputFeatures: return self.features[i] + def get_labels(self): + return self.label_list + if is_tf_available(): import tensorflow as tf @@ -156,6 +179,16 @@ if is_tf_available(): ): processor = hans_processors[task]() label_list = processor.get_labels() + if tokenizer.__class__ in ( + RobertaTokenizer, + RobertaTokenizerFast, + XLMRobertaTokenizer, + BartTokenizer, + BartTokenizerFast, + ): + # HACK(label indices are swapped in RoBERTa pretrained model) + label_list[1], label_list[2] = label_list[2], label_list[1] + self.label_list = label_list examples = processor.get_dev_examples(data_dir) if evaluate else processor.get_train_examples(data_dir) self.features = hans_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer) @@ -206,6 +239,9 @@ if is_tf_available(): def __getitem__(self, i) -> InputFeatures: return self.features[i] + def get_labels(self): + return self.label_list + class HansProcessor(DataProcessor): """Processor for the HANS data set."""