|
|
|
@@ -1,7 +1,10 @@
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
import shutil
|
|
|
|
|
|
|
|
import tempfile
|
|
|
|
import unittest
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import AutoTokenizer, is_torch_available
|
|
|
|
from transformers import BertTokenizer, is_torch_available, set_seed
|
|
|
|
from transformers.testing_utils import require_torch, slow
|
|
|
|
from transformers.testing_utils import require_torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if is_torch_available():
|
|
|
|
if is_torch_available():
|
|
|
|
@@ -12,22 +15,25 @@ if is_torch_available():
|
|
|
|
DataCollatorForNextSentencePrediction,
|
|
|
|
DataCollatorForNextSentencePrediction,
|
|
|
|
DataCollatorForPermutationLanguageModeling,
|
|
|
|
DataCollatorForPermutationLanguageModeling,
|
|
|
|
DataCollatorForSOP,
|
|
|
|
DataCollatorForSOP,
|
|
|
|
GlueDataset,
|
|
|
|
DataCollatorForTokenClassification,
|
|
|
|
GlueDataTrainingArguments,
|
|
|
|
DataCollatorWithPadding,
|
|
|
|
LineByLineTextDataset,
|
|
|
|
|
|
|
|
LineByLineWithSOPTextDataset,
|
|
|
|
|
|
|
|
TextDataset,
|
|
|
|
|
|
|
|
TextDatasetForNextSentencePrediction,
|
|
|
|
|
|
|
|
default_data_collator,
|
|
|
|
default_data_collator,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
|
|
|
|
|
|
|
|
PATH_SAMPLE_TEXT_DIR = "./tests/fixtures/tests_samples/wiki_text"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@require_torch
|
|
|
|
@require_torch
|
|
|
|
class DataCollatorIntegrationTest(unittest.TestCase):
|
|
|
|
class DataCollatorIntegrationTest(unittest.TestCase):
|
|
|
|
|
|
|
|
def setUp(self):
|
|
|
|
|
|
|
|
self.tmpdirname = tempfile.mkdtemp()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
|
|
|
|
|
|
|
self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt")
|
|
|
|
|
|
|
|
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
|
|
|
|
|
|
|
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tearDown(self):
|
|
|
|
|
|
|
|
shutil.rmtree(self.tmpdirname)
|
|
|
|
|
|
|
|
|
|
|
|
def test_default_with_dict(self):
|
|
|
|
def test_default_with_dict(self):
|
|
|
|
features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
|
|
|
features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
|
|
|
batch = default_data_collator(features)
|
|
|
|
batch = default_data_collator(features)
|
|
|
|
@@ -57,6 +63,17 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|
|
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
|
|
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
|
|
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
|
|
|
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 10]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_default_classification_and_regression(self):
|
|
|
|
|
|
|
|
data_collator = default_data_collator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
features = [{"input_ids": [0, 1, 2, 3, 4], "label": i} for i in range(4)]
|
|
|
|
|
|
|
|
batch = data_collator(features)
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
features = [{"input_ids": [0, 1, 2, 3, 4], "label": float(i)} for i in range(4)]
|
|
|
|
|
|
|
|
batch = data_collator(features)
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].dtype, torch.float)
|
|
|
|
|
|
|
|
|
|
|
|
def test_default_with_no_labels(self):
|
|
|
|
def test_default_with_no_labels(self):
|
|
|
|
features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
|
|
|
features = [{"label": None, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(8)]
|
|
|
|
batch = default_data_collator(features)
|
|
|
|
batch = default_data_collator(features)
|
|
|
|
@@ -69,128 +86,144 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|
|
|
self.assertTrue("labels" not in batch)
|
|
|
|
self.assertTrue("labels" not in batch)
|
|
|
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
|
|
|
self.assertEqual(batch["inputs"].shape, torch.Size([8, 6]))
|
|
|
|
|
|
|
|
|
|
|
|
@slow
|
|
|
|
def test_data_collator_with_padding(self):
|
|
|
|
def test_default_classification(self):
|
|
|
|
tokenizer = BertTokenizer(self.vocab_file)
|
|
|
|
MODEL_ID = "bert-base-cased-finetuned-mrpc"
|
|
|
|
features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
|
|
|
|
|
|
data_args = GlueDataTrainingArguments(
|
|
|
|
|
|
|
|
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
|
|
|
|
|
|
|
data_collator = default_data_collator
|
|
|
|
|
|
|
|
batch = data_collator(dataset.features)
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].dtype, torch.long)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@slow
|
|
|
|
data_collator = DataCollatorWithPadding(tokenizer)
|
|
|
|
def test_default_regression(self):
|
|
|
|
batch = data_collator(features)
|
|
|
|
MODEL_ID = "distilroberta-base"
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
|
|
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
|
|
|
data_args = GlueDataTrainingArguments(
|
|
|
|
|
|
|
|
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
|
|
|
|
data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10)
|
|
|
|
)
|
|
|
|
batch = data_collator(features)
|
|
|
|
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
|
|
|
|
data_collator = default_data_collator
|
|
|
|
|
|
|
|
batch = data_collator(dataset.features)
|
|
|
|
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
|
|
|
self.assertEqual(batch["labels"].dtype, torch.float)
|
|
|
|
batch = data_collator(features)
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_data_collator_for_token_classification(self):
|
|
|
|
|
|
|
|
tokenizer = BertTokenizer(self.vocab_file)
|
|
|
|
|
|
|
|
features = [
|
|
|
|
|
|
|
|
{"input_ids": [0, 1, 2], "labels": [0, 1, 2]},
|
|
|
|
|
|
|
|
{"input_ids": [0, 1, 2, 3, 4, 5], "labels": [0, 1, 2, 3, 4, 5]},
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_collator = DataCollatorForTokenClassification(tokenizer)
|
|
|
|
|
|
|
|
batch = data_collator(features)
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-100] * 3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_collator = DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10)
|
|
|
|
|
|
|
|
batch = data_collator(features)
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size([2, 10]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)
|
|
|
|
|
|
|
|
batch = data_collator(features)
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size([2, 8]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1)
|
|
|
|
|
|
|
|
batch = data_collator(features)
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_data_collator_for_language_modeling(self):
|
|
|
|
|
|
|
|
tokenizer = BertTokenizer(self.vocab_file)
|
|
|
|
|
|
|
|
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
|
|
|
|
|
|
|
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
|
|
|
|
|
|
|
|
|
|
|
|
@slow
|
|
|
|
|
|
|
|
def test_lm_tokenizer_without_padding(self):
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
|
|
|
|
|
|
|
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
|
|
|
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
|
|
|
# ^ causal lm
|
|
|
|
batch = data_collator(no_pad_features)
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
|
|
|
|
|
|
|
|
|
|
|
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
|
|
|
batch = data_collator(pad_features)
|
|
|
|
examples = [dataset[i] for i in range(len(dataset))]
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tokenizer._pad_token = None
|
|
|
|
|
|
|
|
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
# Expect error due to padding token missing on gpt2:
|
|
|
|
# Expect error due to padding token missing
|
|
|
|
data_collator(examples)
|
|
|
|
data_collator(pad_features)
|
|
|
|
|
|
|
|
|
|
|
|
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
|
|
|
set_seed(42) # For reproducibility
|
|
|
|
examples = [dataset[i] for i in range(len(dataset))]
|
|
|
|
tokenizer = BertTokenizer(self.vocab_file)
|
|
|
|
batch = data_collator(examples)
|
|
|
|
|
|
|
|
self.assertIsInstance(batch, dict)
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@slow
|
|
|
|
|
|
|
|
def test_lm_tokenizer_with_padding(self):
|
|
|
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
|
|
|
|
|
|
|
|
data_collator = DataCollatorForLanguageModeling(tokenizer)
|
|
|
|
data_collator = DataCollatorForLanguageModeling(tokenizer)
|
|
|
|
# ^ masked lm
|
|
|
|
batch = data_collator(no_pad_features)
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
|
|
|
|
|
|
|
|
|
|
|
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
|
|
|
masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
|
|
|
|
examples = [dataset[i] for i in range(len(dataset))]
|
|
|
|
self.assertTrue(torch.any(masked_tokens))
|
|
|
|
batch = data_collator(examples)
|
|
|
|
self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
|
|
|
|
self.assertIsInstance(batch, dict)
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
|
|
|
batch = data_collator(pad_features)
|
|
|
|
examples = [dataset[i] for i in range(len(dataset))]
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
|
|
|
batch = data_collator(examples)
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
|
|
|
self.assertIsInstance(batch, dict)
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
|
|
|
masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
|
|
|
self.assertTrue(torch.any(masked_tokens))
|
|
|
|
|
|
|
|
self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
|
|
|
|
|
|
|
|
|
|
|
|
@slow
|
|
|
|
|
|
|
|
def test_plm(self):
|
|
|
|
def test_plm(self):
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
|
|
|
|
tokenizer = BertTokenizer(self.vocab_file)
|
|
|
|
|
|
|
|
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
|
|
|
|
|
|
|
|
pad_features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]
|
|
|
|
|
|
|
|
|
|
|
|
data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
|
|
|
|
data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
|
|
|
|
# ^ permutation lm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
|
|
|
batch = data_collator(pad_features)
|
|
|
|
examples = [dataset[i] for i in range(len(dataset))]
|
|
|
|
|
|
|
|
batch = data_collator(examples)
|
|
|
|
|
|
|
|
self.assertIsInstance(batch, dict)
|
|
|
|
self.assertIsInstance(batch, dict)
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 112)))
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
|
|
|
self.assertEqual(batch["perm_mask"].shape, torch.Size((31, 112, 112)))
|
|
|
|
self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 10, 10)))
|
|
|
|
self.assertEqual(batch["target_mapping"].shape, torch.Size((31, 112, 112)))
|
|
|
|
self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 10, 10)))
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((31, 112)))
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
|
|
|
|
|
|
|
|
|
|
|
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
|
|
|
batch = data_collator(no_pad_features)
|
|
|
|
examples = [dataset[i] for i in range(len(dataset))]
|
|
|
|
|
|
|
|
batch = data_collator(examples)
|
|
|
|
|
|
|
|
self.assertIsInstance(batch, dict)
|
|
|
|
self.assertIsInstance(batch, dict)
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
|
|
|
|
self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 512, 512)))
|
|
|
|
self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 10, 10)))
|
|
|
|
self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 512, 512)))
|
|
|
|
self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 10, 10)))
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))
|
|
|
|
|
|
|
|
|
|
|
|
example = [torch.randint(5, [5])]
|
|
|
|
example = [torch.randint(5, [5])]
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
with self.assertRaises(ValueError):
|
|
|
|
# Expect error due to odd sequence length
|
|
|
|
# Expect error due to odd sequence length
|
|
|
|
data_collator(example)
|
|
|
|
data_collator(example)
|
|
|
|
|
|
|
|
|
|
|
|
@slow
|
|
|
|
|
|
|
|
def test_nsp(self):
|
|
|
|
def test_nsp(self):
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
|
|
|
|
tokenizer = BertTokenizer(self.vocab_file)
|
|
|
|
|
|
|
|
features = [{"tokens_a": [0, 1, 2, 3, 4], "tokens_b": [0, 1, 2, 3, 4], "is_random_next": i} for i in range(2)]
|
|
|
|
data_collator = DataCollatorForNextSentencePrediction(tokenizer)
|
|
|
|
data_collator = DataCollatorForNextSentencePrediction(tokenizer)
|
|
|
|
|
|
|
|
batch = data_collator(features)
|
|
|
|
|
|
|
|
|
|
|
|
dataset = TextDatasetForNextSentencePrediction(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
|
|
|
examples = [dataset[i] for i in range(len(dataset))]
|
|
|
|
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 512)))
|
|
|
|
batch = data_collator(examples)
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
|
|
|
self.assertIsInstance(batch, dict)
|
|
|
|
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((2,)))
|
|
|
|
|
|
|
|
|
|
|
|
# Since there are randomly generated false samples, the total number of samples is not fixed.
|
|
|
|
|
|
|
|
total_samples = batch["input_ids"].shape[0]
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@slow
|
|
|
|
|
|
|
|
def test_sop(self):
|
|
|
|
def test_sop(self):
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
|
|
|
|
tokenizer = BertTokenizer(self.vocab_file)
|
|
|
|
|
|
|
|
features = [
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"input_ids": torch.tensor([0, 1, 2, 3, 4]),
|
|
|
|
|
|
|
|
"token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
|
|
|
|
|
|
|
|
"sentence_order_label": torch.tensor(i),
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for i in range(2)
|
|
|
|
|
|
|
|
]
|
|
|
|
data_collator = DataCollatorForSOP(tokenizer)
|
|
|
|
data_collator = DataCollatorForSOP(tokenizer)
|
|
|
|
|
|
|
|
batch = data_collator(features)
|
|
|
|
|
|
|
|
|
|
|
|
dataset = LineByLineWithSOPTextDataset(tokenizer, file_dir=PATH_SAMPLE_TEXT_DIR, block_size=512)
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 5)))
|
|
|
|
examples = [dataset[i] for i in range(len(dataset))]
|
|
|
|
self.assertEqual(batch["token_type_ids"].shape, torch.Size((2, 5)))
|
|
|
|
batch = data_collator(examples)
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((2, 5)))
|
|
|
|
self.assertIsInstance(batch, dict)
|
|
|
|
self.assertEqual(batch["sentence_order_label"].shape, torch.Size((2,)))
|
|
|
|
|
|
|
|
|
|
|
|
# Since there are randomly generated false samples, the total number of samples is not fixed.
|
|
|
|
|
|
|
|
total_samples = batch["input_ids"].shape[0]
|
|
|
|
|
|
|
|
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
|
|
|
|
|
|
|
|
self.assertEqual(batch["sentence_order_label"].shape, torch.Size((total_samples,)))
|
|
|
|
|
|
|
|
|