diff --git a/examples/run_squad.py b/examples/run_squad.py index 86d00bd770..4cd555fa73 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -219,6 +219,11 @@ def train(args, train_dataset, model, tokenizer): inputs.update({"cls_index": batch[5], "p_mask": batch[6]}) if args.version_2_with_negative: inputs.update({"is_impossible": batch[7]}) + if hasattr(model, "config") and hasattr(model.config, "lang2id"): + inputs.update( + {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)} + ) + outputs = model(**inputs) # model outputs are always tuple in transformers (see doc) loss = outputs[0] @@ -330,6 +335,11 @@ def evaluate(args, model, tokenizer, prefix=""): # XLNet and XLM use more arguments for their predictions if args.model_type in ["xlnet", "xlm"]: inputs.update({"cls_index": batch[4], "p_mask": batch[5]}) + # for lang_id-sensitive xlm models + if hasattr(model, "config") and hasattr(model.config, "lang2id"): + inputs.update( + {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)} + ) outputs = model(**inputs) @@ -635,6 +645,12 @@ def main(): help="If true, all of the warnings related to data processing will be printed. " "A number of warnings are expected for a normal SQuAD evaluation.", ) + parser.add_argument( + "--lang_id", + default=0, + type=int, + help="language id of input for language-specific xlm models (see tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)", + ) parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")