Added data collator for permutation (XLNet) language modeling and related calls (#5522)
* Added data collator for XLNet language modeling and related calls Added DataCollatorForXLNetLanguageModeling in data/data_collator.py to generate necessary inputs for language modeling training with XLNetLMHeadModel. Also added related arguments, logic and calls in examples/language-modeling/run_language_modeling.py. Resolves: #4739, #2008 (partially) * Changed name to `DataCollatorForPermutationLanguageModeling` Changed the name of `DataCollatorForXLNetLanguageModeling` to the more general `DataCollatorForPermutationLanguageModelling`. Removed the `--mlm` flag requirement for the new collator and defined a separate `--plm_probability` flag for its use. CTRL uses a CLM loss just like GPT and GPT-2, so should work out of the box with this script (provided `past` is taken care of similar to `mems` for XLNet). Changed calls and imports appropriately. * Added detailed comments, changed variable names Added more detailed comments to `DataCollatorForPermutationLanguageModeling` in `data/data_collator.py` to explain working. Also cleaned up variable names and made them more informative. * Added tests for new data collator Added tests in `tests/test_trainer.py` for DataCollatorForPermutationLanguageModeling based on those in DataCollatorForLanguageModeling. A specific test has been added to check for odd-length sequences. * Fixed styling issues
This commit is contained in:
@@ -12,6 +12,7 @@ if is_torch_available():
|
||||
AutoModelForSequenceClassification,
|
||||
default_data_collator,
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
TextDataset,
|
||||
@@ -123,6 +124,34 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||
|
||||
def test_plm(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
|
||||
data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
|
||||
# ^ permutation lm
|
||||
|
||||
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
batch = data_collator(examples)
|
||||
self.assertIsInstance(batch, dict)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 112)))
|
||||
self.assertEqual(batch["perm_mask"].shape, torch.Size((31, 112, 112)))
|
||||
self.assertEqual(batch["target_mapping"].shape, torch.Size((31, 112, 112)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((31, 112)))
|
||||
|
||||
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
batch = data_collator(examples)
|
||||
self.assertIsInstance(batch, dict)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||
self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 512, 512)))
|
||||
self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 512, 512)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||
|
||||
example = [torch.randint(5, [5])]
|
||||
with self.assertRaises(ValueError):
|
||||
# Expect error due to odd sequence length
|
||||
data_collator(example)
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user