Add missing token classification for XLM (#3277)

* Add the missing token classification for XLM

* fix styling

* Add XLMForTokenClassification to AutoModelForTokenClassification class

* Fix docstring typo for non-existing class

* Add the missing token classification for XLM

* fix styling

* fix styling

* Add XLMForTokenClassification to AutoModelForTokenClassification class

* Fix docstring typo for non-existing class

* Add missing description for AlbertForTokenClassification

* fix styling

* Add missing docstring for AlBert

* Slow tests should be slow

Co-authored-by: Sakares Saengkaew <s.sakares@gmail.com>
Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
sakares saengkaew
2020-03-26 21:22:13 +07:00
committed by GitHub
parent 311970546f
commit 1a6c546c6f
5 changed files with 169 additions and 22 deletions

View File

@@ -222,6 +222,7 @@ if is_torch_available():
XLMModel,
XLMWithLMHeadModel,
XLMForSequenceClassification,
XLMForTokenClassification,
XLMForQuestionAnswering,
XLMForQuestionAnsweringSimple,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,

View File

@@ -99,6 +99,7 @@ from .modeling_xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMForQuestionAnsweringSimple,
XLMForSequenceClassification,
XLMForTokenClassification,
XLMModel,
XLMWithLMHeadModel,
)
@@ -235,6 +236,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
[
(DistilBertConfig, DistilBertForTokenClassification),
(CamembertConfig, CamembertForTokenClassification),
(XLMConfig, XLMForTokenClassification),
(XLMRobertaConfig, XLMRobertaForTokenClassification),
(RobertaConfig, RobertaForTokenClassification),
(BertConfig, BertForTokenClassification),
@@ -418,12 +420,12 @@ class AutoModelForPreTraining(object):
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModelForMaskedLM` (DistilBERT model)
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaModelForMaskedLM` (RoBERTa model)
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
- isInstance of `bert` configuration class: :class:`~transformers.BertForPreTraining` (Bert model)
- isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
- isInstance of `gpt2` configuration class: :class:`~transformers.GPT2ModelLMHeadModel` (OpenAI GPT-2 model)
- isInstance of `ctrl` configuration class: :class:`~transformers.CTRLModelLMHeadModel` (Salesforce CTRL model)
- isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
- isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
- isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
- isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
@@ -559,12 +561,12 @@ class AutoModelWithLMHead(object):
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModelForMaskedLM` (DistilBERT model)
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaModelForMaskedLM` (RoBERTa model)
- isInstance of `bert` configuration class: :class:`~transformers.BertModelForMaskedLM` (Bert model)
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
- isInstance of `bert` configuration class: :class:`~transformers.BertForMaskedLM` (Bert model)
- isInstance of `openai-gpt` configuration class: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
- isInstance of `gpt2` configuration class: :class:`~transformers.GPT2ModelLMHeadModel` (OpenAI GPT-2 model)
- isInstance of `ctrl` configuration class: :class:`~transformers.CTRLModelLMHeadModel` (Salesforce CTRL model)
- isInstance of `gpt2` configuration class: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
- isInstance of `ctrl` configuration class: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
- isInstance of `transfo-xl` configuration class: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
- isInstance of `xlm` configuration class: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
@@ -701,14 +703,14 @@ class AutoModelForSequenceClassification(object):
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModelForSequenceClassification` (DistilBERT model)
- isInstance of `albert` configuration class: :class:`~transformers.AlbertModelForSequenceClassification` (ALBERT model)
- isInstance of `camembert` configuration class: :class:`~transformers.CamembertModelForSequenceClassification` (CamemBERT model)
- isInstance of `xlm roberta` configuration class: :class:`~transformers.XLMRobertaModelForSequenceClassification` (XLM-RoBERTa model)
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaModelForSequenceClassification` (RoBERTa model)
- isInstance of `bert` configuration class: :class:`~transformers.BertModelForSequenceClassification` (Bert model)
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetModelForSequenceClassification` (XLNet model)
- isInstance of `xlm` configuration class: :class:`~transformers.XLMModelForSequenceClassification` (XLM model)
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
- isInstance of `albert` configuration class: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
- isInstance of `camembert` configuration class: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
- isInstance of `xlm roberta` configuration class: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
- isInstance of `bert` configuration class: :class:`~transformers.BertForSequenceClassification` (Bert model)
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
- isInstance of `xlm` configuration class: :class:`~transformers.XLMForSequenceClassification` (XLM model)
- isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)
@@ -848,11 +850,11 @@ class AutoModelForQuestionAnswering(object):
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModelForQuestionAnswering` (DistilBERT model)
- isInstance of `albert` configuration class: :class:`~transformers.AlbertModelForQuestionAnswering` (ALBERT model)
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
- isInstance of `albert` configuration class: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
- isInstance of `bert` configuration class: :class:`~transformers.BertModelForQuestionAnswering` (Bert model)
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetModelForQuestionAnswering` (XLNet model)
- isInstance of `xlm` configuration class: :class:`~transformers.XLMModelForQuestionAnswering` (XLM model)
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
- isInstance of `xlm` configuration class: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
- isInstance of `flaubert` configuration class: :class:`~transformers.FlaubertForQuestionAnswering` (XLM model)
Examples::
@@ -989,8 +991,10 @@ class AutoModelForTokenClassification:
The model class to instantiate is selected based on the configuration class:
- isInstance of `distilbert` configuration class: :class:`~transformers.DistilBertModelForTokenClassification` (DistilBERT model)
- isInstance of `xlm` configuration class: :class:`~transformers.XLMForTokenClassification` (XLM model)
- isInstance of `xlm roberta` configuration class: :class:`~transformers.XLMRobertaModelForTokenClassification` (XLMRoberta model)
- isInstance of `bert` configuration class: :class:`~transformers.BertModelForTokenClassification` (Bert model)
- isInstance of `albert` configuration class: :class:`~transformers.AlbertForTokenClassification` (AlBert model)
- isInstance of `xlnet` configuration class: :class:`~transformers.XLNetModelForTokenClassification` (XLNet model)
- isInstance of `camembert` configuration class: :class:`~transformers.CamembertModelForTokenClassification` (Camembert model)
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaModelForTokenClassification` (Roberta model)
@@ -1025,6 +1029,7 @@ class AutoModelForTokenClassification:
The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
- contains `xlm`: :class:`~transformers.XLMForTokenClassification` (XLM model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa?Para model)
- contains `camembert`: :class:`~transformers.CamembertForTokenClassification` (Camembert model)
- contains `bert`: :class:`~transformers.BertForTokenClassification` (Bert model)

View File

@@ -1040,3 +1040,98 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
outputs = outputs + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
return outputs
@add_start_docstrings(
"""XLM Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
XLM_START_DOCSTRING,
)
class XLMForTokenClassification(XLMPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = XLMModel(config)
self.dropout = nn.Dropout(config.dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
@add_start_docstrings_to_callable(XLM_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import XLMTokenizer, XLMForTokenClassification
import torch
tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-100-1280')
model = XLMForTokenClassification.from_pretrained('xlm-mlm-100-1280')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, scores = outputs[:2]
"""
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)