bidirectional conversion TF <=> PT - extended tests
This commit is contained in:
@@ -1,5 +1,20 @@
|
||||
version: 2
|
||||
jobs:
|
||||
build_py3_torch_and_tf:
|
||||
working_directory: ~/pytorch-transformers
|
||||
docker:
|
||||
- image: circleci/python:3.5
|
||||
resource_class: xlarge
|
||||
parallelism: 1
|
||||
steps:
|
||||
- checkout
|
||||
- run: sudo pip install torch
|
||||
- run: sudo pip install tensorflow==2.0.0-rc0
|
||||
- run: sudo pip install --progress-bar off .
|
||||
- run: sudo pip install pytest codecov pytest-cov
|
||||
- run: sudo pip install tensorboardX scikit-learn
|
||||
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
|
||||
- run: codecov
|
||||
build_py3_torch:
|
||||
working_directory: ~/pytorch-transformers
|
||||
docker:
|
||||
|
||||
@@ -73,7 +73,8 @@ if _torch_available:
|
||||
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
|
||||
XLMWithLMHeadModel, XLMForSequenceClassification,
|
||||
XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
|
||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
|
||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
|
||||
@@ -150,6 +151,15 @@ if _tf_available:
|
||||
load_distilbert_pt_weights_in_tf2,
|
||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
|
||||
if _tf_available and _torch_available:
|
||||
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
||||
load_pytorch_checkpoint_in_tf2_model,
|
||||
load_pytorch_weights_in_tf2_model,
|
||||
load_pytorch_model_in_tf2_model,
|
||||
load_tf2_checkpoint_in_pytorch_model,
|
||||
load_tf2_weights_in_pytorch_model,
|
||||
load_tf2_model_in_pytorch_model)
|
||||
|
||||
# Files and general utilities
|
||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||
cached_path, add_start_docstrings, add_end_docstrings,
|
||||
|
||||
@@ -20,15 +20,49 @@ from __future__ import (absolute_import, division, print_function,
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import numpy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None):
|
||||
""" Load pytorch checkpoints in a TF 2.0 model
|
||||
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
|
||||
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
|
||||
|
||||
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
|
||||
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
||||
- '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
||||
|
||||
return tuple with:
|
||||
- pytorch model weight name
|
||||
- transpose: boolean indicating weither TF2.0 and PyTorch weights matrices are transposed with regards to each other
|
||||
"""
|
||||
tf_name = tf_name.replace(':0', '') # device ids
|
||||
tf_name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', tf_name) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
||||
tf_name = tf_name.replace('_._', '/') # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
||||
tf_name = re.sub(r'//+', '/', tf_name) # Remove empty levels at the end
|
||||
tf_name = tf_name.split('/') # Convert from TF2.0 '/' separators to PyTorch '.' separators
|
||||
tf_name = tf_name[1:] # Remove level zero
|
||||
|
||||
# When should we transpose the weights
|
||||
transpose = bool(tf_name[-1] == 'kernel' or 'emb_projs' in tf_name or 'out_projs' in tf_name)
|
||||
|
||||
# Convert standard TF2.0 names in PyTorch names
|
||||
if tf_name[-1] == 'kernel' or tf_name[-1] == 'embeddings' or tf_name[-1] == 'gamma':
|
||||
tf_name[-1] = 'weight'
|
||||
if tf_name[-1] == 'beta':
|
||||
tf_name[-1] = 'bias'
|
||||
|
||||
# Remove prefix if needed
|
||||
tf_name = '.'.join(tf_name)
|
||||
if start_prefix_to_remove:
|
||||
tf_name = tf_name.replace(start_prefix_to_remove, '', 1)
|
||||
|
||||
return tf_name, transpose
|
||||
|
||||
|
||||
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None):
|
||||
""" Load pytorch checkpoints in a TF 2.0 model
|
||||
"""
|
||||
try:
|
||||
import tensorflow as tf
|
||||
@@ -43,25 +77,31 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
|
||||
|
||||
pt_state_dict = torch.load(pt_path, map_location='cpu')
|
||||
|
||||
return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)
|
||||
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)
|
||||
|
||||
|
||||
def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
|
||||
def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None):
|
||||
""" Load pytorch checkpoints in a TF 2.0 model
|
||||
"""
|
||||
pt_state_dict = pt_model.state_dict()
|
||||
|
||||
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)
|
||||
|
||||
|
||||
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
|
||||
""" Load pytorch state_dict in a TF 2.0 model.
|
||||
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
|
||||
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
||||
- '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
import torch
|
||||
import numpy
|
||||
from tensorflow.python.keras import backend as K
|
||||
except ImportError as e:
|
||||
logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise e
|
||||
|
||||
if tf_inputs is not None:
|
||||
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
|
||||
|
||||
# Adapt state dict - TODO remove this and update the AWS weights files instead
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
old_keys = []
|
||||
@@ -89,27 +129,8 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
|
||||
weight_value_tuples = []
|
||||
all_pytorch_weights = set(list(pt_state_dict.keys()))
|
||||
for symbolic_weight in symbolic_weights:
|
||||
name = symbolic_weight.name
|
||||
name = name.replace(':0', '') # device ids
|
||||
name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', name) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
||||
name = name.replace('_._', '/') # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
||||
name = re.sub(r'//+', '/', name) # Remove empty levels at the end
|
||||
name = name.split('/') # Convert from TF2.0 '/' separators to PyTorch '.' separators
|
||||
name = name[1:] # Remove level zero
|
||||
|
||||
# When should we transpose the weights
|
||||
transpose = bool(name[-1] == 'kernel' or 'emb_projs' in name or 'out_projs' in name)
|
||||
|
||||
# Convert standard TF2.0 names in PyTorch names
|
||||
if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
|
||||
name[-1] = 'weight'
|
||||
if name[-1] == 'beta':
|
||||
name[-1] = 'bias'
|
||||
|
||||
# Remove prefix if needed
|
||||
name = '.'.join(name)
|
||||
if start_prefix_to_remove:
|
||||
name = name.replace(start_prefix_to_remove, '', 1)
|
||||
sw_name = symbolic_weight.name
|
||||
name, transpose = convert_tf_weight_name_to_pt_weight_name(sw_name, start_prefix_to_remove=start_prefix_to_remove)
|
||||
|
||||
# Find associated numpy array in pytorch model state dict
|
||||
assert name in pt_state_dict, "{} not found in PyTorch model".format(name)
|
||||
@@ -144,13 +165,10 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
|
||||
return tf_model
|
||||
|
||||
|
||||
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path):
|
||||
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None):
|
||||
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
|
||||
We use HDF5 to easily do transfer learning
|
||||
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
|
||||
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
|
||||
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
||||
- '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
||||
"""
|
||||
try:
|
||||
import tensorflow as tf
|
||||
@@ -161,13 +179,97 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path):
|
||||
raise e
|
||||
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info("Loading TensorFlow weights from {}".format(tf_path))
|
||||
logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path))
|
||||
|
||||
tf_state_dict = torch.load(tf_path, map_location='cpu')
|
||||
# Instantiate and load the associated TF 2.0 model
|
||||
tf_model_class_name = "TF" + model_class.__name__ # Add "TF" at the beggining
|
||||
tf_model_class = getattr(pytorch_transformers, tf_model_class_name)
|
||||
tf_model = tf_model_class(pt_model.config)
|
||||
|
||||
return load_tf2_weights_in_pytorch_model(pt_model, tf_state_dict)
|
||||
if tf_inputs is not None:
|
||||
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
|
||||
|
||||
def load_tf2_weights_in_pytorch_model(pt_model, tf_model):
|
||||
tf_model.load_weights(tf_checkpoint_path, by_name=True)
|
||||
|
||||
return load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||
|
||||
def load_tf2_model_in_pytorch_model(pt_model, tf_model):
|
||||
""" Load TF 2.0 model in a pytorch model
|
||||
"""
|
||||
weights = tf_model.weights
|
||||
|
||||
return load_tf2_weights_in_pytorch_model(pt_model, weights)
|
||||
|
||||
|
||||
def load_tf2_weights_in_pytorch_model(pt_model, tf_weights):
|
||||
""" Load TF2.0 symbolic weights in a PyTorch model
|
||||
"""
|
||||
raise NotImplementedError
|
||||
try:
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
except ImportError as e:
|
||||
logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise e
|
||||
|
||||
new_pt_params_dict = {}
|
||||
current_pt_params_dict = dict(pt_model.named_parameters())
|
||||
|
||||
# Make sure we are able to load PyTorch base models as well as derived models (with heads)
|
||||
# TF models always have a prefix, some of PyTorch models (base ones) don't
|
||||
start_prefix_to_remove = ''
|
||||
if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict.keys()):
|
||||
start_prefix_to_remove = pt_model.base_model_prefix + '.'
|
||||
|
||||
# Build a map from potential PyTorch weight names to TF 2.0 Variables
|
||||
tf_weights_map = {}
|
||||
for tf_weight in tf_weights:
|
||||
pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(tf_weight.name, start_prefix_to_remove=start_prefix_to_remove)
|
||||
tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
|
||||
|
||||
all_tf_weights = set(list(tf_weights_map.keys()))
|
||||
loaded_pt_weights_data_ptr = {}
|
||||
for pt_weight_name, pt_weight in current_pt_params_dict.items():
|
||||
# Handle PyTorch shared weight ()not duplicated in TF 2.0
|
||||
if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
|
||||
new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
|
||||
continue
|
||||
|
||||
# Find associated numpy array in pytorch model state dict
|
||||
if pt_weight_name not in tf_weights_map:
|
||||
raise ValueError("{} not found in TF 2.0 model".format(pt_weight_name))
|
||||
|
||||
array, transpose = tf_weights_map[pt_weight_name]
|
||||
|
||||
if transpose:
|
||||
array = numpy.transpose(array)
|
||||
|
||||
if len(pt_weight.shape) < len(array.shape):
|
||||
array = numpy.squeeze(array)
|
||||
elif len(pt_weight.shape) > len(array.shape):
|
||||
array = numpy.expand_dims(array, axis=0)
|
||||
|
||||
try:
|
||||
assert list(pt_weight.shape) == list(array.shape)
|
||||
except AssertionError as e:
|
||||
e.args += (pt_weight.shape, array.shape)
|
||||
raise e
|
||||
|
||||
logger.info("Initialize PyTorch weight {}".format(pt_weight_name))
|
||||
|
||||
new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
|
||||
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
|
||||
all_tf_weights.discard(pt_weight_name)
|
||||
|
||||
missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
|
||||
|
||||
if len(missing_keys) > 0:
|
||||
logger.info("Weights of {} not initialized from TF 2.0 model: {}".format(
|
||||
pt_model.__class__.__name__, missing_keys))
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.info("Weights from TF 2.0 model not used in {}: {}".format(
|
||||
pt_model.__class__.__name__, unexpected_keys))
|
||||
|
||||
logger.info("Weights or buffers not loaded from TF 2.0 model: {}".format(all_tf_weights))
|
||||
|
||||
return pt_model
|
||||
|
||||
@@ -718,6 +718,101 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
||||
@add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
|
||||
class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||
Position outside of the sequence are not taken into account for computing the loss.
|
||||
**is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels whether a question has an answer or no answer (SQuAD 2.0)
|
||||
**cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
|
||||
**p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-start scores (before SoftMax).
|
||||
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-end scores (before SoftMax).
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
|
||||
model = XLMForQuestionAnsweringSimple.from_pretrained('xlm-mlm-en-2048')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
start_positions = torch.tensor([1])
|
||||
end_positions = torch.tensor([3])
|
||||
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
|
||||
loss, start_scores, end_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(XLMForQuestionAnsweringSimple, self).__init__(config)
|
||||
|
||||
self.transformer = XLMModel(config)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
||||
lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None):
|
||||
transformer_outputs = self.transformer(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
langs=langs,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
lengths=lengths,
|
||||
cache=cache,
|
||||
head_mask=head_mask)
|
||||
|
||||
sequence_output = transformer_outputs[0]
|
||||
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
end_logits = end_logits.squeeze(-1)
|
||||
|
||||
outputs = (start_logits, end_logits,)
|
||||
if start_positions is not None and end_positions is not None:
|
||||
# If we are on multi-GPU, split add a dimension
|
||||
if len(start_positions.size()) > 1:
|
||||
start_positions = start_positions.squeeze(-1)
|
||||
if len(end_positions.size()) > 1:
|
||||
end_positions = end_positions.squeeze(-1)
|
||||
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||
ignored_index = start_logits.size(1)
|
||||
start_positions.clamp_(0, ignored_index)
|
||||
end_positions.clamp_(0, ignored_index)
|
||||
|
||||
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||
start_loss = loss_fct(start_logits, start_positions)
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
outputs = outputs + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@add_start_docstrings("""XLM Model with a beam-search span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
|
||||
class XLMForQuestionAnswering(XLMPreTrainedModel):
|
||||
r"""
|
||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
|
||||
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import importlib
|
||||
import random
|
||||
import shutil
|
||||
import unittest
|
||||
@@ -25,7 +26,7 @@ import uuid
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
from pytorch_transformers import is_tf_available
|
||||
from pytorch_transformers import is_tf_available, is_torch_available
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
@@ -66,6 +67,24 @@ class TFCommonTestCases:
|
||||
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
||||
|
||||
|
||||
def test_pt_tf_model_equivalence(self):
|
||||
if not is_torch_available():
|
||||
pass
|
||||
import pytorch_transformers
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining
|
||||
pt_model_class = getattr(pytorch_transformers, pt_model_class_name)
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
tf_model = pytorch_transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
|
||||
pt_model = pytorch_transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||
|
||||
|
||||
def test_keyword_and_dict_args(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -225,7 +225,7 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, token_type_ids, input_lengths,
|
||||
sequence_labels, token_labels, is_impossible_labels, input_mask) = config_and_inputs
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
|
||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'langs': token_type_ids, 'lengths': input_lengths}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
|
||||
@@ -24,7 +24,7 @@ from pytorch_transformers import is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
|
||||
XLMForSequenceClassification)
|
||||
XLMForSequenceClassification, XLMForQuestionAnsweringSimple)
|
||||
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
else:
|
||||
pytestmark = pytest.mark.skip("Require Torch")
|
||||
@@ -36,7 +36,7 @@ from .configuration_common_test import ConfigTester
|
||||
class XLMModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
|
||||
XLMForSequenceClassification) if is_torch_available() else ()
|
||||
XLMForSequenceClassification, XLMForQuestionAnsweringSimple) if is_torch_available() else ()
|
||||
|
||||
|
||||
class XLMModelTester(object):
|
||||
@@ -180,6 +180,30 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
|
||||
def create_and_check_xlm_simple_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
|
||||
model = XLMForQuestionAnsweringSimple(config)
|
||||
model.eval()
|
||||
|
||||
outputs = model(input_ids)
|
||||
|
||||
outputs = model(input_ids, start_positions=sequence_labels,
|
||||
end_positions=sequence_labels)
|
||||
loss, start_logits, end_logits = outputs
|
||||
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_logits"].size()),
|
||||
[self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
|
||||
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
|
||||
model = XLMForQuestionAnswering(config)
|
||||
model.eval()
|
||||
@@ -276,6 +300,10 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)
|
||||
|
||||
def test_xlm_simple_qa(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlm_simple_qa(*config_and_inputs)
|
||||
|
||||
def test_xlm_qa(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_xlm_qa(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user