Trainer (#3800)
* doc
* [tests] Add sample files for a regression task
* [HUGE] Trainer
* Feedback from @sshleifer
* Feedback from @thomwolf + logging tweak
* [file_utils] when downloading concurrently, get_from_cache will use the cached file for subsequent processes
* [glue] Use default max_seq_length of 128 like before
* [glue] move DataTrainingArguments around
* [ner] Change interface of InputExample, and align run_{tf,pl}
* Re-align the pl scripts a little bit
* ner
* [ner] Add integration test
* Fix language_modeling with API tweak
* [ci] Tweak loss target
* Don't break console output
* amp.initialize: model must be on right device before
* [multiple-choice] update for Trainer
* Re-align to 827d6d6ef0
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
|
||||
""" Multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
|
||||
|
||||
|
||||
import csv
|
||||
@@ -21,48 +21,124 @@ import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, torch_distributed_zero_first
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InputExample(object):
|
||||
"""A single training/test example for multiple choice"""
|
||||
@dataclass(frozen=True)
|
||||
class InputExample:
|
||||
"""
|
||||
A single training/test example for multiple choice
|
||||
|
||||
def __init__(self, example_id, question, contexts, endings, label=None):
|
||||
"""Constructs a InputExample.
|
||||
Args:
|
||||
example_id: Unique id for the example.
|
||||
question: string. The untokenized text of the second sequence (question).
|
||||
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
|
||||
endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
"""
|
||||
|
||||
Args:
|
||||
example_id: Unique id for the example.
|
||||
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
|
||||
question: string. The untokenized text of the second sequence (question).
|
||||
endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
"""
|
||||
self.example_id = example_id
|
||||
self.question = question
|
||||
self.contexts = contexts
|
||||
self.endings = endings
|
||||
self.label = label
|
||||
example_id: str
|
||||
question: str
|
||||
contexts: List[str]
|
||||
endings: List[str]
|
||||
label: Optional[str]
|
||||
|
||||
|
||||
class InputFeatures(object):
|
||||
def __init__(self, example_id, choices_features, label):
|
||||
self.example_id = example_id
|
||||
self.choices_features = [
|
||||
{"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}
|
||||
for input_ids, input_mask, segment_ids in choices_features
|
||||
]
|
||||
self.label = label
|
||||
@dataclass(frozen=True)
|
||||
class InputFeatures:
|
||||
"""
|
||||
A single set of features of data.
|
||||
Property names are the same names as the corresponding inputs to a model.
|
||||
"""
|
||||
|
||||
example_id: str
|
||||
input_ids: List[List[int]]
|
||||
attention_mask: Optional[List[List[int]]]
|
||||
token_type_ids: Optional[List[List[int]]]
|
||||
label: Optional[int]
|
||||
|
||||
|
||||
class DataProcessor(object):
|
||||
class Split(Enum):
|
||||
train = "train"
|
||||
dev = "dev"
|
||||
test = "test"
|
||||
|
||||
|
||||
class MultipleChoiceDataset(Dataset):
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
task: str,
|
||||
max_seq_length: Optional[int] = None,
|
||||
overwrite_cache=False,
|
||||
mode: Split = Split.train,
|
||||
local_rank=-1,
|
||||
):
|
||||
processor = processors[task]()
|
||||
|
||||
cached_features_file = os.path.join(
|
||||
data_dir,
|
||||
"cached_{}_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length), task,),
|
||||
)
|
||||
with torch_distributed_zero_first(local_rank):
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||
label_list = processor.get_labels()
|
||||
if mode == Split.dev:
|
||||
examples = processor.get_dev_examples(data_dir)
|
||||
elif mode == Split.test:
|
||||
examples = processor.get_test_examples(data_dir)
|
||||
else:
|
||||
examples = processor.get_train_examples(data_dir)
|
||||
logger.info("Training examples: %s", len(examples))
|
||||
# TODO clean up all this to leverage built-in features of tokenizers
|
||||
self.features = convert_examples_to_features(
|
||||
examples,
|
||||
label_list,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
pad_on_left=bool(tokenizer.padding_side == "left"),
|
||||
pad_token=tokenizer.pad_token_id,
|
||||
pad_token_segment_id=tokenizer.pad_token_type_id,
|
||||
)
|
||||
if local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(self.features, cached_features_file)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
"""Base class for data converters for multiple choice data sets."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
@@ -311,7 +387,7 @@ def convert_examples_to_features(
|
||||
for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||
choices_features = []
|
||||
choices_inputs = []
|
||||
for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
|
||||
text_a = context
|
||||
if example.question.find("_") != -1:
|
||||
@@ -321,7 +397,7 @@ def convert_examples_to_features(
|
||||
text_b = example.question + " " + ending
|
||||
|
||||
inputs = tokenizer.encode_plus(
|
||||
text_a, text_b, add_special_tokens=True, max_length=max_length, return_token_type_ids=True
|
||||
text_a, text_b, add_special_tokens=True, max_length=max_length, pad_to_max_length=True,
|
||||
)
|
||||
if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
|
||||
logger.info(
|
||||
@@ -330,41 +406,31 @@ def convert_examples_to_features(
|
||||
"you need to try to use a bigger max seq length!"
|
||||
)
|
||||
|
||||
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
padding_length = max_length - len(input_ids)
|
||||
if pad_on_left:
|
||||
input_ids = ([pad_token] * padding_length) + input_ids
|
||||
attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
|
||||
token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
|
||||
else:
|
||||
input_ids = input_ids + ([pad_token] * padding_length)
|
||||
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
||||
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
|
||||
|
||||
assert len(input_ids) == max_length
|
||||
assert len(attention_mask) == max_length
|
||||
assert len(token_type_ids) == max_length
|
||||
choices_features.append((input_ids, attention_mask, token_type_ids))
|
||||
choices_inputs.append(inputs)
|
||||
|
||||
label = label_map[example.label]
|
||||
|
||||
if ex_index < 2:
|
||||
logger.info("*** Example ***")
|
||||
logger.info("race_id: {}".format(example.example_id))
|
||||
for choice_idx, (input_ids, attention_mask, token_type_ids) in enumerate(choices_features):
|
||||
logger.info("choice: {}".format(choice_idx))
|
||||
logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
|
||||
logger.info("attention_mask: {}".format(" ".join(map(str, attention_mask))))
|
||||
logger.info("token_type_ids: {}".format(" ".join(map(str, token_type_ids))))
|
||||
logger.info("label: {}".format(label))
|
||||
input_ids = [x["input_ids"] for x in choices_inputs]
|
||||
attention_mask = (
|
||||
[x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None
|
||||
)
|
||||
token_type_ids = (
|
||||
[x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None
|
||||
)
|
||||
|
||||
features.append(InputFeatures(example_id=example.example_id, choices_features=choices_features, label=label,))
|
||||
features.append(
|
||||
InputFeatures(
|
||||
example_id=example.example_id,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
label=label,
|
||||
)
|
||||
)
|
||||
|
||||
for f in features[:2]:
|
||||
logger.info("*** Example ***")
|
||||
logger.info("feature: %s" % f)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
Reference in New Issue
Block a user