diff --git a/.circleci/config.yml b/.circleci/config.yml index c49cf7df8a..30a4458807 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,20 @@ version: 2 jobs: + build_py3_torch_and_tf: + working_directory: ~/pytorch-transformers + docker: + - image: circleci/python:3.5 + resource_class: xlarge + parallelism: 1 + steps: + - checkout + - run: sudo pip install torch + - run: sudo pip install tensorflow==2.0.0-rc0 + - run: sudo pip install --progress-bar off . + - run: sudo pip install pytest codecov pytest-cov + - run: sudo pip install tensorboardX scikit-learn + - run: python -m pytest -sv ./pytorch_transformers/tests/ --cov + - run: codecov build_py3_torch: working_directory: ~/pytorch-transformers docker: diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index 907115f70d..b8c7eccfe7 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -73,7 +73,8 @@ if _torch_available: load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_xlm import (XLMPreTrainedModel , XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, - XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP) + XLMForQuestionAnswering, XLMForQuestionAnsweringSimple, + XLM_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel, @@ -150,6 +151,15 @@ if _tf_available: load_distilbert_pt_weights_in_tf2, TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) +if _tf_available and _torch_available: + from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name, + load_pytorch_checkpoint_in_tf2_model, + load_pytorch_weights_in_tf2_model, + load_pytorch_model_in_tf2_model, + load_tf2_checkpoint_in_pytorch_model, + load_tf2_weights_in_pytorch_model, + load_tf2_model_in_pytorch_model) + # Files and general 
utilities from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path, add_start_docstrings, add_end_docstrings, diff --git a/pytorch_transformers/modeling_tf_pytorch_utils.py b/pytorch_transformers/modeling_tf_pytorch_utils.py index 12e5023802..8e879e1447 100644 --- a/pytorch_transformers/modeling_tf_pytorch_utils.py +++ b/pytorch_transformers/modeling_tf_pytorch_utils.py @@ -20,15 +20,49 @@ from __future__ import (absolute_import, division, print_function, import logging import os +import re +import numpy logger = logging.getLogger(__name__) -def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None): - """ Load pytorch checkpoints in a TF 2.0 model +def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''): + """ Convert a TF 2.0 model variable name to a PyTorch model weight name. + Conventions for TF2.0 scopes -> PyTorch attribute names conversions: - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) + + return tuple with: + - pytorch model weight name + - transpose: boolean indicating whether TF2.0 and PyTorch weights matrices are transposed with regards to each other + """ + tf_name = tf_name.replace(':0', '') # device ids + tf_name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', tf_name) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) + tf_name = tf_name.replace('_._', '/') # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) + tf_name = re.sub(r'//+', '/', tf_name) # Remove empty levels at the end + tf_name = tf_name.split('/') # Convert from TF2.0 '/' separators to PyTorch '.' 
separators + tf_name = tf_name[1:] # Remove level zero + + # When should we transpose the weights + transpose = bool(tf_name[-1] == 'kernel' or 'emb_projs' in tf_name or 'out_projs' in tf_name) + + # Convert standard TF2.0 names in PyTorch names + if tf_name[-1] == 'kernel' or tf_name[-1] == 'embeddings' or tf_name[-1] == 'gamma': + tf_name[-1] = 'weight' + if tf_name[-1] == 'beta': + tf_name[-1] = 'bias' + + # Remove prefix if needed + tf_name = '.'.join(tf_name) + if start_prefix_to_remove: + tf_name = tf_name.replace(start_prefix_to_remove, '', 1) + + return tf_name, transpose + + +def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None): + """ Load pytorch checkpoints in a TF 2.0 model """ try: import tensorflow as tf @@ -43,25 +77,31 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i pt_state_dict = torch.load(pt_path, map_location='cpu') - return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs) + return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs) -def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None): +def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None): + """ Load the state_dict of an instantiated PyTorch model in a TF 2.0 model + """ + pt_state_dict = pt_model.state_dict() + + return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs) + + +def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None): """ Load pytorch state_dict in a TF 2.0 model. 
- Conventions for TF2.0 scopes -> PyTorch attribute names conversions: - - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) - - '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) """ try: - import re import torch - import numpy from tensorflow.python.keras import backend as K except ImportError as e: logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see " "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") raise e + if tf_inputs is not None: + tfo = tf_model(tf_inputs, training=False) # Make sure model is built + # Adapt state dict - TODO remove this and update the AWS weights files instead # Convert old format to new format if needed from a PyTorch state_dict old_keys = [] @@ -89,27 +129,8 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None weight_value_tuples = [] all_pytorch_weights = set(list(pt_state_dict.keys())) for symbolic_weight in symbolic_weights: - name = symbolic_weight.name - name = name.replace(':0', '') # device ids - name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', name) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) - name = name.replace('_._', '/') # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) - name = re.sub(r'//+', '/', name) # Remove empty levels at the end - name = name.split('/') # Convert from TF2.0 '/' separators to PyTorch '.' 
separators - name = name[1:] # Remove level zero - - # When should we transpose the weights - transpose = bool(name[-1] == 'kernel' or 'emb_projs' in name or 'out_projs' in name) - - # Convert standard TF2.0 names in PyTorch names - if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma': - name[-1] = 'weight' - if name[-1] == 'beta': - name[-1] = 'bias' - - # Remove prefix if needed - name = '.'.join(name) - if start_prefix_to_remove: - name = name.replace(start_prefix_to_remove, '', 1) + sw_name = symbolic_weight.name + name, transpose = convert_tf_weight_name_to_pt_weight_name(sw_name, start_prefix_to_remove=start_prefix_to_remove) # Find associated numpy array in pytorch model state dict assert name in pt_state_dict, "{} not found in PyTorch model".format(name) @@ -144,13 +165,10 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None return tf_model -def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path): +def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None): """ Load TF 2.0 HDF5 checkpoint in a PyTorch model We use HDF5 to easily do transfer learning (see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357). 
- Conventions for TF2.0 scopes -> PyTorch attribute names conversions: - - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch) - - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList) """ try: import tensorflow as tf @@ -161,13 +179,98 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path): raise e tf_path = os.path.abspath(tf_checkpoint_path) - logger.info("Loading TensorFlow weights from {}".format(tf_path)) + logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path)) - tf_state_dict = torch.load(tf_path, map_location='cpu') + # Instantiate and load the associated TF 2.0 model + import pytorch_transformers # Imported locally: the package __init__ imports this module + tf_model_class_name = "TF" + pt_model.__class__.__name__ # Add "TF" at the beginning + tf_model_class = getattr(pytorch_transformers, tf_model_class_name) + tf_model = tf_model_class(pt_model.config) - return load_tf2_weights_in_pytorch_model(pt_model, tf_state_dict) + if tf_inputs is not None: + tfo = tf_model(tf_inputs, training=False) # Make sure model is built -def load_tf2_weights_in_pytorch_model(pt_model, tf_model): + tf_model.load_weights(tf_checkpoint_path, by_name=True) + + return load_tf2_model_in_pytorch_model(pt_model, tf_model) + +def load_tf2_model_in_pytorch_model(pt_model, tf_model): + """ Load TF 2.0 model in a pytorch model + """ + weights = tf_model.weights + + return load_tf2_weights_in_pytorch_model(pt_model, weights) + + +def load_tf2_weights_in_pytorch_model(pt_model, tf_weights): """ Load TF2.0 symbolic weights in a PyTorch model """ - raise NotImplementedError + try: + import tensorflow as tf + import torch + except ImportError as e: + logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. 
Please see " + "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") + raise e + + new_pt_params_dict = {} + current_pt_params_dict = dict(pt_model.named_parameters()) + + # Make sure we are able to load PyTorch base models as well as derived models (with heads) + # TF models always have a prefix, some of PyTorch models (base ones) don't + start_prefix_to_remove = '' + if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict.keys()): + start_prefix_to_remove = pt_model.base_model_prefix + '.' + + # Build a map from potential PyTorch weight names to TF 2.0 Variables + tf_weights_map = {} + for tf_weight in tf_weights: + pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(tf_weight.name, start_prefix_to_remove=start_prefix_to_remove) + tf_weights_map[pt_name] = (tf_weight.numpy(), transpose) + + all_tf_weights = set(list(tf_weights_map.keys())) + loaded_pt_weights_data_ptr = {} + for pt_weight_name, pt_weight in current_pt_params_dict.items(): + # Handle PyTorch shared weights (not duplicated in TF 2.0) + if pt_weight.data_ptr() in loaded_pt_weights_data_ptr: + new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()] + continue + + # Find associated numpy array in the TF 2.0 weights map + if pt_weight_name not in tf_weights_map: + raise ValueError("{} not found in TF 2.0 model".format(pt_weight_name)) + + array, transpose = tf_weights_map[pt_weight_name] + + if transpose: + array = numpy.transpose(array) + + if len(pt_weight.shape) < len(array.shape): + array = numpy.squeeze(array) + elif len(pt_weight.shape) > len(array.shape): + array = numpy.expand_dims(array, axis=0) + + try: + assert list(pt_weight.shape) == list(array.shape) + except AssertionError as e: + e.args += (pt_weight.shape, array.shape) + raise e + + logger.info("Initialize PyTorch weight {}".format(pt_weight_name)) + + new_pt_params_dict[pt_weight_name] = torch.from_numpy(array) + 
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array) + all_tf_weights.discard(pt_weight_name) + + missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False) + + if len(missing_keys) > 0: + logger.info("Weights of {} not initialized from TF 2.0 model: {}".format( + pt_model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + logger.info("Weights from TF 2.0 model not used in {}: {}".format( + pt_model.__class__.__name__, unexpected_keys)) + + logger.info("Weights or buffers not loaded from TF 2.0 model: {}".format(all_tf_weights)) + + return pt_model diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 92febd296d..95629ba535 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -718,6 +718,101 @@ class XLMForSequenceClassification(XLMPreTrainedModel): @add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """, XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING) +class XLMForQuestionAnsweringSimple(XLMPreTrainedModel): + r""" + **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. 
+ **is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels whether a question has an answer or no answer (SQuAD 2.0) + **cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for position (index) of the classification token to use as input for computing plausibility of the answer. + **p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...) + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` + Span-start scores (before SoftMax). + **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)`` + Span-end scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
+ + Examples:: + + tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048') + model = XLMForQuestionAnsweringSimple.from_pretrained('xlm-mlm-en-2048') + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + start_positions = torch.tensor([1]) + end_positions = torch.tensor([3]) + outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions) + loss, start_scores, end_scores = outputs[:3] + + """ + def __init__(self, config): + super(XLMForQuestionAnsweringSimple, self).__init__(config) + + self.transformer = XLMModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None, + lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None): + transformer_outputs = self.transformer(input_ids, + attention_mask=attention_mask, + langs=langs, + token_type_ids=token_type_ids, + position_ids=position_ids, + lengths=lengths, + cache=cache, + head_mask=head_mask) + + sequence_output = transformer_outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + outputs = (start_logits, end_logits,) + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = 
loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + outputs = (total_loss,) + outputs + + outputs = outputs + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here + + return outputs + + +@add_start_docstrings("""XLM Model with a beam-search span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of + the hidden-states output to compute `span start logits` and `span end logits`). """, + XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING) class XLMForQuestionAnswering(XLMPreTrainedModel): r""" **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: diff --git a/pytorch_transformers/tests/modeling_tf_common_test.py b/pytorch_transformers/tests/modeling_tf_common_test.py index 332db01408..ac25320189 100644 --- a/pytorch_transformers/tests/modeling_tf_common_test.py +++ b/pytorch_transformers/tests/modeling_tf_common_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function import copy import json import logging +import importlib import random import shutil import unittest @@ -25,7 +26,7 @@ import uuid import pytest import sys -from pytorch_transformers import is_tf_available +from pytorch_transformers import is_tf_available, is_torch_available if is_tf_available(): import tensorflow as tf @@ -66,6 +67,24 @@ class TFCommonTestCases: # msg="Parameter {} of model {} seems not properly initialized".format(name, model_class)) + def test_pt_tf_model_equivalence(self): + if not is_torch_available(): + return + import pytorch_transformers + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning + pt_model_class = getattr(pytorch_transformers, pt_model_class_name) + + tf_model = model_class(config) + pt_model = pt_model_class(config) + + tf_model = 
pytorch_transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict) + pt_model = pytorch_transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model) + + def test_keyword_and_dict_args(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/pytorch_transformers/tests/modeling_tf_xlm_test.py b/pytorch_transformers/tests/modeling_tf_xlm_test.py index 5aed818ddd..26329bebb6 100644 --- a/pytorch_transformers/tests/modeling_tf_xlm_test.py +++ b/pytorch_transformers/tests/modeling_tf_xlm_test.py @@ -225,7 +225,7 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask) = config_and_inputs - inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths} + inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'langs': token_type_ids, 'lengths': input_lengths} return config, inputs_dict def setUp(self): diff --git a/pytorch_transformers/tests/modeling_xlm_test.py b/pytorch_transformers/tests/modeling_xlm_test.py index 4f7f81c002..c3f75e2623 100644 --- a/pytorch_transformers/tests/modeling_xlm_test.py +++ b/pytorch_transformers/tests/modeling_xlm_test.py @@ -24,7 +24,7 @@ from pytorch_transformers import is_torch_available if is_torch_available(): from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, - XLMForSequenceClassification) + XLMForSequenceClassification, XLMForQuestionAnsweringSimple) from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP else: pytestmark = pytest.mark.skip("Require Torch") @@ -36,7 +36,7 @@ from .configuration_common_test import ConfigTester class XLMModelTest(CommonTestCases.CommonModelTester): all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, - 
XLMForSequenceClassification) if is_torch_available() else () + XLMForSequenceClassification, XLMForQuestionAnsweringSimple) if is_torch_available() else () class XLMModelTester(object): @@ -180,6 +180,30 @@ class XLMModelTest(CommonTestCases.CommonModelTester): [self.batch_size, self.seq_length, self.vocab_size]) + def create_and_check_xlm_simple_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): + model = XLMForQuestionAnsweringSimple(config) + model.eval() + + outputs = model(input_ids) + + outputs = model(input_ids, start_positions=sequence_labels, + end_positions=sequence_labels) + loss, start_logits, end_logits = outputs + + result = { + "loss": loss, + "start_logits": start_logits, + "end_logits": end_logits, + } + self.parent.assertListEqual( + list(result["start_logits"].size()), + [self.batch_size, self.seq_length]) + self.parent.assertListEqual( + list(result["end_logits"].size()), + [self.batch_size, self.seq_length]) + self.check_loss_output(result) + + def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask): model = XLMForQuestionAnswering(config) model.eval() @@ -276,6 +300,10 @@ class XLMModelTest(CommonTestCases.CommonModelTester): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs) + def test_xlm_simple_qa(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_xlm_simple_qa(*config_and_inputs) + def test_xlm_qa(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_xlm_qa(*config_and_inputs)