bidirectional conversion TF <=> PT - extended tests
This commit is contained in:
@@ -1,5 +1,20 @@
|
|||||||
version: 2
|
version: 2
|
||||||
jobs:
|
jobs:
|
||||||
|
build_py3_torch_and_tf:
|
||||||
|
working_directory: ~/pytorch-transformers
|
||||||
|
docker:
|
||||||
|
- image: circleci/python:3.5
|
||||||
|
resource_class: xlarge
|
||||||
|
parallelism: 1
|
||||||
|
steps:
|
||||||
|
- checkout
|
||||||
|
- run: sudo pip install torch
|
||||||
|
- run: sudo pip install tensorflow==2.0.0-rc0
|
||||||
|
- run: sudo pip install --progress-bar off .
|
||||||
|
- run: sudo pip install pytest codecov pytest-cov
|
||||||
|
- run: sudo pip install tensorboardX scikit-learn
|
||||||
|
- run: python -m pytest -sv ./pytorch_transformers/tests/ --cov
|
||||||
|
- run: codecov
|
||||||
build_py3_torch:
|
build_py3_torch:
|
||||||
working_directory: ~/pytorch-transformers
|
working_directory: ~/pytorch-transformers
|
||||||
docker:
|
docker:
|
||||||
|
|||||||
@@ -73,7 +73,8 @@ if _torch_available:
|
|||||||
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
|
from .modeling_xlm import (XLMPreTrainedModel , XLMModel,
|
||||||
XLMWithLMHeadModel, XLMForSequenceClassification,
|
XLMWithLMHeadModel, XLMForSequenceClassification,
|
||||||
XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
XLMForQuestionAnswering, XLMForQuestionAnsweringSimple,
|
||||||
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
|
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification,
|
||||||
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
|
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
|
||||||
@@ -150,6 +151,15 @@ if _tf_available:
|
|||||||
load_distilbert_pt_weights_in_tf2,
|
load_distilbert_pt_weights_in_tf2,
|
||||||
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
|
|
||||||
|
if _tf_available and _torch_available:
|
||||||
|
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
|
||||||
|
load_pytorch_checkpoint_in_tf2_model,
|
||||||
|
load_pytorch_weights_in_tf2_model,
|
||||||
|
load_pytorch_model_in_tf2_model,
|
||||||
|
load_tf2_checkpoint_in_pytorch_model,
|
||||||
|
load_tf2_weights_in_pytorch_model,
|
||||||
|
load_tf2_model_in_pytorch_model)
|
||||||
|
|
||||||
# Files and general utilities
|
# Files and general utilities
|
||||||
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE,
|
||||||
cached_path, add_start_docstrings, add_end_docstrings,
|
cached_path, add_start_docstrings, add_end_docstrings,
|
||||||
|
|||||||
@@ -20,15 +20,49 @@ from __future__ import (absolute_import, division, print_function,
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import numpy
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None):
|
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=''):
|
||||||
""" Load pytorch checkpoints in a TF 2.0 model
|
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
|
||||||
|
|
||||||
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
|
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
|
||||||
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
||||||
- '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
- '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
||||||
|
|
||||||
|
return tuple with:
|
||||||
|
- pytorch model weight name
|
||||||
|
- transpose: boolean indicating weither TF2.0 and PyTorch weights matrices are transposed with regards to each other
|
||||||
|
"""
|
||||||
|
tf_name = tf_name.replace(':0', '') # device ids
|
||||||
|
tf_name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', tf_name) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
||||||
|
tf_name = tf_name.replace('_._', '/') # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
||||||
|
tf_name = re.sub(r'//+', '/', tf_name) # Remove empty levels at the end
|
||||||
|
tf_name = tf_name.split('/') # Convert from TF2.0 '/' separators to PyTorch '.' separators
|
||||||
|
tf_name = tf_name[1:] # Remove level zero
|
||||||
|
|
||||||
|
# When should we transpose the weights
|
||||||
|
transpose = bool(tf_name[-1] == 'kernel' or 'emb_projs' in tf_name or 'out_projs' in tf_name)
|
||||||
|
|
||||||
|
# Convert standard TF2.0 names in PyTorch names
|
||||||
|
if tf_name[-1] == 'kernel' or tf_name[-1] == 'embeddings' or tf_name[-1] == 'gamma':
|
||||||
|
tf_name[-1] = 'weight'
|
||||||
|
if tf_name[-1] == 'beta':
|
||||||
|
tf_name[-1] = 'bias'
|
||||||
|
|
||||||
|
# Remove prefix if needed
|
||||||
|
tf_name = '.'.join(tf_name)
|
||||||
|
if start_prefix_to_remove:
|
||||||
|
tf_name = tf_name.replace(start_prefix_to_remove, '', 1)
|
||||||
|
|
||||||
|
return tf_name, transpose
|
||||||
|
|
||||||
|
|
||||||
|
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None):
|
||||||
|
""" Load pytorch checkpoints in a TF 2.0 model
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -43,25 +77,31 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
|
|||||||
|
|
||||||
pt_state_dict = torch.load(pt_path, map_location='cpu')
|
pt_state_dict = torch.load(pt_path, map_location='cpu')
|
||||||
|
|
||||||
return load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)
|
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)
|
||||||
|
|
||||||
|
|
||||||
def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
|
def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None):
|
||||||
|
""" Load pytorch checkpoints in a TF 2.0 model
|
||||||
|
"""
|
||||||
|
pt_state_dict = pt_model.state_dict()
|
||||||
|
|
||||||
|
return load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=tf_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None):
|
||||||
""" Load pytorch state_dict in a TF 2.0 model.
|
""" Load pytorch state_dict in a TF 2.0 model.
|
||||||
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
|
|
||||||
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
|
||||||
- '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import re
|
|
||||||
import torch
|
import torch
|
||||||
import numpy
|
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
logger.error("Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
|
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
if tf_inputs is not None:
|
||||||
|
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
|
||||||
|
|
||||||
# Adapt state dict - TODO remove this and update the AWS weights files instead
|
# Adapt state dict - TODO remove this and update the AWS weights files instead
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
# Convert old format to new format if needed from a PyTorch state_dict
|
||||||
old_keys = []
|
old_keys = []
|
||||||
@@ -89,27 +129,8 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
|
|||||||
weight_value_tuples = []
|
weight_value_tuples = []
|
||||||
all_pytorch_weights = set(list(pt_state_dict.keys()))
|
all_pytorch_weights = set(list(pt_state_dict.keys()))
|
||||||
for symbolic_weight in symbolic_weights:
|
for symbolic_weight in symbolic_weights:
|
||||||
name = symbolic_weight.name
|
sw_name = symbolic_weight.name
|
||||||
name = name.replace(':0', '') # device ids
|
name, transpose = convert_tf_weight_name_to_pt_weight_name(sw_name, start_prefix_to_remove=start_prefix_to_remove)
|
||||||
name = re.sub(r'/[^/]*___([^/]*)/', r'/\1/', name) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
|
||||||
name = name.replace('_._', '/') # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
|
||||||
name = re.sub(r'//+', '/', name) # Remove empty levels at the end
|
|
||||||
name = name.split('/') # Convert from TF2.0 '/' separators to PyTorch '.' separators
|
|
||||||
name = name[1:] # Remove level zero
|
|
||||||
|
|
||||||
# When should we transpose the weights
|
|
||||||
transpose = bool(name[-1] == 'kernel' or 'emb_projs' in name or 'out_projs' in name)
|
|
||||||
|
|
||||||
# Convert standard TF2.0 names in PyTorch names
|
|
||||||
if name[-1] == 'kernel' or name[-1] == 'embeddings' or name[-1] == 'gamma':
|
|
||||||
name[-1] = 'weight'
|
|
||||||
if name[-1] == 'beta':
|
|
||||||
name[-1] = 'bias'
|
|
||||||
|
|
||||||
# Remove prefix if needed
|
|
||||||
name = '.'.join(name)
|
|
||||||
if start_prefix_to_remove:
|
|
||||||
name = name.replace(start_prefix_to_remove, '', 1)
|
|
||||||
|
|
||||||
# Find associated numpy array in pytorch model state dict
|
# Find associated numpy array in pytorch model state dict
|
||||||
assert name in pt_state_dict, "{} not found in PyTorch model".format(name)
|
assert name in pt_state_dict, "{} not found in PyTorch model".format(name)
|
||||||
@@ -144,13 +165,10 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
|
|||||||
return tf_model
|
return tf_model
|
||||||
|
|
||||||
|
|
||||||
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path):
|
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None):
|
||||||
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
|
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
|
||||||
We use HDF5 to easily do transfer learning
|
We use HDF5 to easily do transfer learning
|
||||||
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
|
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
|
||||||
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
|
|
||||||
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
|
|
||||||
- '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -161,13 +179,97 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||||
logger.info("Loading TensorFlow weights from {}".format(tf_path))
|
logger.info("Loading TensorFlow weights from {}".format(tf_checkpoint_path))
|
||||||
|
|
||||||
tf_state_dict = torch.load(tf_path, map_location='cpu')
|
# Instantiate and load the associated TF 2.0 model
|
||||||
|
tf_model_class_name = "TF" + model_class.__name__ # Add "TF" at the beggining
|
||||||
|
tf_model_class = getattr(pytorch_transformers, tf_model_class_name)
|
||||||
|
tf_model = tf_model_class(pt_model.config)
|
||||||
|
|
||||||
return load_tf2_weights_in_pytorch_model(pt_model, tf_state_dict)
|
if tf_inputs is not None:
|
||||||
|
tfo = tf_model(tf_inputs, training=False) # Make sure model is built
|
||||||
|
|
||||||
def load_tf2_weights_in_pytorch_model(pt_model, tf_model):
|
tf_model.load_weights(tf_checkpoint_path, by_name=True)
|
||||||
|
|
||||||
|
return load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||||
|
|
||||||
|
def load_tf2_model_in_pytorch_model(pt_model, tf_model):
|
||||||
|
""" Load TF 2.0 model in a pytorch model
|
||||||
|
"""
|
||||||
|
weights = tf_model.weights
|
||||||
|
|
||||||
|
return load_tf2_weights_in_pytorch_model(pt_model, weights)
|
||||||
|
|
||||||
|
|
||||||
|
def load_tf2_weights_in_pytorch_model(pt_model, tf_weights):
|
||||||
""" Load TF2.0 symbolic weights in a PyTorch model
|
""" Load TF2.0 symbolic weights in a PyTorch model
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
try:
|
||||||
|
import tensorflow as tf
|
||||||
|
import torch
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error("Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
|
||||||
|
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
new_pt_params_dict = {}
|
||||||
|
current_pt_params_dict = dict(pt_model.named_parameters())
|
||||||
|
|
||||||
|
# Make sure we are able to load PyTorch base models as well as derived models (with heads)
|
||||||
|
# TF models always have a prefix, some of PyTorch models (base ones) don't
|
||||||
|
start_prefix_to_remove = ''
|
||||||
|
if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict.keys()):
|
||||||
|
start_prefix_to_remove = pt_model.base_model_prefix + '.'
|
||||||
|
|
||||||
|
# Build a map from potential PyTorch weight names to TF 2.0 Variables
|
||||||
|
tf_weights_map = {}
|
||||||
|
for tf_weight in tf_weights:
|
||||||
|
pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(tf_weight.name, start_prefix_to_remove=start_prefix_to_remove)
|
||||||
|
tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
|
||||||
|
|
||||||
|
all_tf_weights = set(list(tf_weights_map.keys()))
|
||||||
|
loaded_pt_weights_data_ptr = {}
|
||||||
|
for pt_weight_name, pt_weight in current_pt_params_dict.items():
|
||||||
|
# Handle PyTorch shared weight ()not duplicated in TF 2.0
|
||||||
|
if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
|
||||||
|
new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find associated numpy array in pytorch model state dict
|
||||||
|
if pt_weight_name not in tf_weights_map:
|
||||||
|
raise ValueError("{} not found in TF 2.0 model".format(pt_weight_name))
|
||||||
|
|
||||||
|
array, transpose = tf_weights_map[pt_weight_name]
|
||||||
|
|
||||||
|
if transpose:
|
||||||
|
array = numpy.transpose(array)
|
||||||
|
|
||||||
|
if len(pt_weight.shape) < len(array.shape):
|
||||||
|
array = numpy.squeeze(array)
|
||||||
|
elif len(pt_weight.shape) > len(array.shape):
|
||||||
|
array = numpy.expand_dims(array, axis=0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert list(pt_weight.shape) == list(array.shape)
|
||||||
|
except AssertionError as e:
|
||||||
|
e.args += (pt_weight.shape, array.shape)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
logger.info("Initialize PyTorch weight {}".format(pt_weight_name))
|
||||||
|
|
||||||
|
new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
|
||||||
|
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
|
||||||
|
all_tf_weights.discard(pt_weight_name)
|
||||||
|
|
||||||
|
missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
|
||||||
|
|
||||||
|
if len(missing_keys) > 0:
|
||||||
|
logger.info("Weights of {} not initialized from TF 2.0 model: {}".format(
|
||||||
|
pt_model.__class__.__name__, missing_keys))
|
||||||
|
if len(unexpected_keys) > 0:
|
||||||
|
logger.info("Weights from TF 2.0 model not used in {}: {}".format(
|
||||||
|
pt_model.__class__.__name__, unexpected_keys))
|
||||||
|
|
||||||
|
logger.info("Weights or buffers not loaded from TF 2.0 model: {}".format(all_tf_weights))
|
||||||
|
|
||||||
|
return pt_model
|
||||||
|
|||||||
@@ -718,6 +718,101 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
|
|||||||
@add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
@add_start_docstrings("""XLM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||||
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||||
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
|
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
|
||||||
|
class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
|
||||||
|
r"""
|
||||||
|
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
|
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
||||||
|
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||||
|
Position outside of the sequence are not taken into account for computing the loss.
|
||||||
|
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
|
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
||||||
|
Positions are clamped to the length of the sequence (`sequence_length`).
|
||||||
|
Position outside of the sequence are not taken into account for computing the loss.
|
||||||
|
**is_impossible**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
|
Labels whether a question has an answer or no answer (SQuAD 2.0)
|
||||||
|
**cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
|
Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
|
||||||
|
**p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
|
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)
|
||||||
|
|
||||||
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
|
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||||
|
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
||||||
|
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||||
|
Span-start scores (before SoftMax).
|
||||||
|
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||||
|
Span-end scores (before SoftMax).
|
||||||
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||||
|
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||||
|
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||||
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
|
||||||
|
model = XLMForQuestionAnsweringSimple.from_pretrained('xlm-mlm-en-2048')
|
||||||
|
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||||
|
start_positions = torch.tensor([1])
|
||||||
|
end_positions = torch.tensor([3])
|
||||||
|
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
|
||||||
|
loss, start_scores, end_scores = outputs[:2]
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, config):
|
||||||
|
super(XLMForQuestionAnsweringSimple, self).__init__(config)
|
||||||
|
|
||||||
|
self.transformer = XLMModel(config)
|
||||||
|
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def forward(self, input_ids, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
|
||||||
|
lengths=None, cache=None, head_mask=None, start_positions=None, end_positions=None):
|
||||||
|
transformer_outputs = self.transformer(input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
langs=langs,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
lengths=lengths,
|
||||||
|
cache=cache,
|
||||||
|
head_mask=head_mask)
|
||||||
|
|
||||||
|
sequence_output = transformer_outputs[0]
|
||||||
|
|
||||||
|
logits = self.qa_outputs(sequence_output)
|
||||||
|
start_logits, end_logits = logits.split(1, dim=-1)
|
||||||
|
start_logits = start_logits.squeeze(-1)
|
||||||
|
end_logits = end_logits.squeeze(-1)
|
||||||
|
|
||||||
|
outputs = (start_logits, end_logits,)
|
||||||
|
if start_positions is not None and end_positions is not None:
|
||||||
|
# If we are on multi-GPU, split add a dimension
|
||||||
|
if len(start_positions.size()) > 1:
|
||||||
|
start_positions = start_positions.squeeze(-1)
|
||||||
|
if len(end_positions.size()) > 1:
|
||||||
|
end_positions = end_positions.squeeze(-1)
|
||||||
|
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
||||||
|
ignored_index = start_logits.size(1)
|
||||||
|
start_positions.clamp_(0, ignored_index)
|
||||||
|
end_positions.clamp_(0, ignored_index)
|
||||||
|
|
||||||
|
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
||||||
|
start_loss = loss_fct(start_logits, start_positions)
|
||||||
|
end_loss = loss_fct(end_logits, end_positions)
|
||||||
|
total_loss = (start_loss + end_loss) / 2
|
||||||
|
outputs = (total_loss,) + outputs
|
||||||
|
|
||||||
|
outputs = outputs + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings("""XLM Model with a beam-search span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||||
|
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
||||||
|
XLM_START_DOCSTRING, XLM_INPUTS_DOCSTRING)
|
||||||
class XLMForQuestionAnswering(XLMPreTrainedModel):
|
class XLMForQuestionAnswering(XLMPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from __future__ import absolute_import, division, print_function
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import importlib
|
||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
import unittest
|
import unittest
|
||||||
@@ -25,7 +26,7 @@ import uuid
|
|||||||
import pytest
|
import pytest
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from pytorch_transformers import is_tf_available
|
from pytorch_transformers import is_tf_available, is_torch_available
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -66,6 +67,24 @@ class TFCommonTestCases:
|
|||||||
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
# msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
|
||||||
|
|
||||||
|
|
||||||
|
def test_pt_tf_model_equivalence(self):
|
||||||
|
if not is_torch_available():
|
||||||
|
pass
|
||||||
|
import pytorch_transformers
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining
|
||||||
|
pt_model_class = getattr(pytorch_transformers, pt_model_class_name)
|
||||||
|
|
||||||
|
tf_model = model_class(config)
|
||||||
|
pt_model = pt_model_class(config)
|
||||||
|
|
||||||
|
tf_model = pytorch_transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
|
||||||
|
pt_model = pytorch_transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||||
|
|
||||||
|
|
||||||
def test_keyword_and_dict_args(self):
|
def test_keyword_and_dict_args(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -225,7 +225,7 @@ class TFXLMModelTest(TFCommonTestCases.TFCommonModelTester):
|
|||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(config, input_ids, token_type_ids, input_lengths,
|
(config, input_ids, token_type_ids, input_lengths,
|
||||||
sequence_labels, token_labels, is_impossible_labels, input_mask) = config_and_inputs
|
sequence_labels, token_labels, is_impossible_labels, input_mask) = config_and_inputs
|
||||||
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'lengths': input_lengths}
|
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'langs': token_type_ids, 'lengths': input_lengths}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from pytorch_transformers import is_torch_available
|
|||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
|
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
|
||||||
XLMForSequenceClassification)
|
XLMForSequenceClassification, XLMForQuestionAnsweringSimple)
|
||||||
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
else:
|
else:
|
||||||
pytestmark = pytest.mark.skip("Require Torch")
|
pytestmark = pytest.mark.skip("Require Torch")
|
||||||
@@ -36,7 +36,7 @@ from .configuration_common_test import ConfigTester
|
|||||||
class XLMModelTest(CommonTestCases.CommonModelTester):
|
class XLMModelTest(CommonTestCases.CommonModelTester):
|
||||||
|
|
||||||
all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
|
all_model_classes = (XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering,
|
||||||
XLMForSequenceClassification) if is_torch_available() else ()
|
XLMForSequenceClassification, XLMForQuestionAnsweringSimple) if is_torch_available() else ()
|
||||||
|
|
||||||
|
|
||||||
class XLMModelTester(object):
|
class XLMModelTester(object):
|
||||||
@@ -180,6 +180,30 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
|||||||
[self.batch_size, self.seq_length, self.vocab_size])
|
[self.batch_size, self.seq_length, self.vocab_size])
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_check_xlm_simple_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
|
||||||
|
model = XLMForQuestionAnsweringSimple(config)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
outputs = model(input_ids)
|
||||||
|
|
||||||
|
outputs = model(input_ids, start_positions=sequence_labels,
|
||||||
|
end_positions=sequence_labels)
|
||||||
|
loss, start_logits, end_logits = outputs
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"loss": loss,
|
||||||
|
"start_logits": start_logits,
|
||||||
|
"end_logits": end_logits,
|
||||||
|
}
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["start_logits"].size()),
|
||||||
|
[self.batch_size, self.seq_length])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["end_logits"].size()),
|
||||||
|
[self.batch_size, self.seq_length])
|
||||||
|
self.check_loss_output(result)
|
||||||
|
|
||||||
|
|
||||||
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
|
def create_and_check_xlm_qa(self, config, input_ids, token_type_ids, input_lengths, sequence_labels, token_labels, is_impossible_labels, input_mask):
|
||||||
model = XLMForQuestionAnswering(config)
|
model = XLMForQuestionAnswering(config)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -276,6 +300,10 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)
|
self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_xlm_simple_qa(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_xlm_simple_qa(*config_and_inputs)
|
||||||
|
|
||||||
def test_xlm_qa(self):
|
def test_xlm_qa(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_xlm_qa(*config_and_inputs)
|
self.model_tester.create_and_check_xlm_qa(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user