Added data collator for permutation (XLNet) language modeling and related calls (#5522)
* Added data collator for XLNet language modeling and related calls

  Added DataCollatorForXLNetLanguageModeling in data/data_collator.py to generate the inputs needed for language modeling training with XLNetLMHeadModel. Also added the related arguments, logic and calls in examples/language-modeling/run_language_modeling.py.

  Resolves: #4739, #2008 (partially)

* Changed name to `DataCollatorForPermutationLanguageModeling`

  Changed the name of `DataCollatorForXLNetLanguageModeling` to the more general `DataCollatorForPermutationLanguageModeling`. Removed the `--mlm` flag requirement for the new collator and defined a separate `--plm_probability` flag for its use. CTRL uses a CLM loss just like GPT and GPT-2, so it should work out of the box with this script (provided `past` is handled the same way `mems` is for XLNet). Changed calls and imports accordingly.

* Added detailed comments, changed variable names

  Added more detailed comments to `DataCollatorForPermutationLanguageModeling` in `data/data_collator.py` to explain how it works. Also cleaned up variable names and made them more informative.

* Added tests for new data collator

  Added tests in `tests/test_trainer.py` for DataCollatorForPermutationLanguageModeling, based on those for DataCollatorForLanguageModeling. A specific test has been added to check for odd-length sequences.

* Fixed styling issues
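The help strings for `--plm_probability` and `--max_span_length` in the diff describe the first as the ratio of a masked span's length to its surrounding context length, and the second as the maximum span length. The following is a minimal sketch of one span-sampling scheme consistent with those descriptions. It is an illustration under that assumption, not necessarily the exact logic of `DataCollatorForPermutationLanguageModeling`; the function name `sample_span_mask` and its `seed` parameter are hypothetical.

```python
import random


def sample_span_mask(seq_len, plm_probability=1 / 6, max_span_length=5, seed=None):
    """Return a list of booleans marking which positions are masked.

    Hypothetical helper: walks the sequence window by window and masks
    one contiguous span inside each window.
    """
    rng = random.Random(seed)
    mask = [False] * seq_len
    cur = 0
    while cur < seq_len:
        # Span length is drawn from 1..max_span_length (inclusive).
        span_length = rng.randint(1, max_span_length)
        # The window holding this span is sized so that
        # span_length / context_length is roughly plm_probability.
        context_length = int(span_length / plm_probability)
        # Place the span at a random offset inside its window.
        start = cur + rng.randint(0, context_length - span_length)
        for i in range(start, min(start + span_length, seq_len)):
            mask[i] = True
        # Move on to the next window.
        cur += context_length
    return mask


if __name__ == "__main__":
    mask = sample_span_mask(60, seed=0)
    # Print masked positions as "x" and unmasked positions as ".".
    print("".join("x" if m else "." for m in mask))
```

With the defaults from the diff (`1 / 6` and `5`), roughly one token in six ends up masked, in spans of at most five tokens per window. Spans from neighbouring windows can touch, so a contiguous masked run may occasionally be longer than `max_span_length`.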
@@ -14,9 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
-GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
-using a masked language modeling (MLM) loss.
+Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet).
+GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned
+using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
 """
 
 
@@ -33,6 +33,7 @@ from transformers import (
     AutoModelWithLMHead,
     AutoTokenizer,
     DataCollatorForLanguageModeling,
+    DataCollatorForPermutationLanguageModeling,
     HfArgumentParser,
     LineByLineTextDataset,
     PreTrainedTokenizer,
@@ -101,6 +102,15 @@ class DataTrainingArguments:
     mlm_probability: float = field(
         default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
     )
+    plm_probability: float = field(
+        default=1 / 6,
+        metadata={
+            "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling."
+        },
+    )
+    max_span_length: int = field(
+        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
+    )
 
     block_size: int = field(
         default=-1,
@@ -207,8 +217,8 @@ def main():
 
     if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
         raise ValueError(
-            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
-            "flag (masked language modeling)."
+            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the "
+            "--mlm flag (masked language modeling)."
         )
 
     if data_args.block_size <= 0:
@@ -221,9 +231,14 @@ def main():
 
     train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
     eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
-    data_collator = DataCollatorForLanguageModeling(
-        tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
-    )
+    if config.model_type == "xlnet":
+        data_collator = DataCollatorForPermutationLanguageModeling(
+            tokenizer=tokenizer, plm_probability=data_args.plm_probability, max_span_length=data_args.max_span_length,
+        )
+    else:
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
+        )
 
     # Initialize our Trainer
     trainer = Trainer(
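
Taken together, the hunks above make the script pick a training objective from `config.model_type` and the `--mlm` flag: XLNet always gets the permutation collator, BERT-like models require `--mlm`, and everything else falls back to `DataCollatorForLanguageModeling` with or without masking. The sketch below restates that decision as a small standalone function so it can be read and tested without `transformers` installed. The names `choose_objective` and `MLM_MODEL_TYPES` are hypothetical and do not appear in the script.

```python
# Model types that only have masked-LM heads (list taken from the check in main()).
MLM_MODEL_TYPES = {"bert", "roberta", "distilbert", "camembert"}


def choose_objective(model_type: str, mlm: bool = False) -> str:
    """Return "plm", "mlm" or "clm" for a model type and --mlm setting.

    Hypothetical helper mirroring the order of checks in the diff:
    the --mlm validation runs first, then the XLNet branch.
    """
    if model_type in MLM_MODEL_TYPES and not mlm:
        raise ValueError(
            "BERT and RoBERTa-like models must be run using the --mlm flag."
        )
    if model_type == "xlnet":
        # The permutation collator is selected regardless of --mlm.
        return "plm"
    # All other models share one collator; --mlm toggles masking.
    return "mlm" if mlm else "clm"


if __name__ == "__main__":
    for model_type, mlm in [("gpt2", False), ("bert", True), ("xlnet", False)]:
        print(model_type, mlm, choose_objective(model_type, mlm))
```

In the script itself the returned label would correspond to constructing `DataCollatorForPermutationLanguageModeling` for "plm" and `DataCollatorForLanguageModeling` with `mlm=True` or `mlm=False` for the other two cases.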