Added data collator for permutation (XLNet) language modeling and related calls (#5522)

* Added data collator for XLNet language modeling and related calls

Added DataCollatorForXLNetLanguageModeling in data/data_collator.py
to generate necessary inputs for language modeling training with
XLNetLMHeadModel. Also added related arguments, logic and calls in
examples/language-modeling/run_language_modeling.py.

Resolves: #4739, #2008 (partially)

* Changed name to `DataCollatorForPermutationLanguageModeling`

Changed the name of `DataCollatorForXLNetLanguageModeling` to the more general `DataCollatorForPermutationLanguageModelling`.
Removed the `--mlm` flag requirement for the new collator and defined a separate `--plm_probability` flag for its use.
CTRL uses a CLM loss just like GPT and GPT-2, so should work out of the box with this script (provided `past` is taken care of
similar to `mems` for XLNet).
Changed calls and imports appropriately.

* Added detailed comments, changed variable names

Added more detailed comments to `DataCollatorForPermutationLanguageModeling` in `data/data_collator.py` to explain working. Also cleaned up variable names and made them more informative.

* Added tests for new data collator

Added tests in `tests/test_trainer.py` for DataCollatorForPermutationLanguageModeling based on those in DataCollatorForLanguageModeling. A specific test has been added to check for odd-length sequences.

* Fixed styling issues
This commit is contained in:
Shashank Gupta
2020-07-07 13:47:37 +05:30
committed by GitHub
parent 1d2332861f
commit 3dcb748e31
4 changed files with 183 additions and 11 deletions

View File

@@ -12,6 +12,7 @@ if is_torch_available():
AutoModelForSequenceClassification,
default_data_collator,
DataCollatorForLanguageModeling,
DataCollatorForPermutationLanguageModeling,
GlueDataset,
GlueDataTrainingArguments,
TextDataset,
@@ -123,6 +124,34 @@ class DataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
def test_plm(self):
tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
data_collator = DataCollatorForPermutationLanguageModeling(tokenizer)
# ^ permutation lm
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 112)))
self.assertEqual(batch["perm_mask"].shape, torch.Size((31, 112, 112)))
self.assertEqual(batch["target_mapping"].shape, torch.Size((31, 112, 112)))
self.assertEqual(batch["labels"].shape, torch.Size((31, 112)))
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["perm_mask"].shape, torch.Size((2, 512, 512)))
self.assertEqual(batch["target_mapping"].shape, torch.Size((2, 512, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
example = [torch.randint(5, [5])]
with self.assertRaises(ValueError):
# Expect error due to odd sequence length
data_collator(example)
@require_torch
class TrainerIntegrationTest(unittest.TestCase):