Ctrl for sequence classification (#8812)

* add CTRLForSequenceClassification

* pass local test

* merge with master

* fix modeling test for sequence classification

* fix deco

* fix assert
This commit is contained in:
elk-cloner
2020-12-01 12:19:27 +03:30
committed by GitHub
parent 7f34d75780
commit 4a9e502a36
7 changed files with 167 additions and 7 deletions

View File

@@ -26,7 +26,13 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
if is_torch_available():
import torch
from transformers import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLConfig, CTRLLMHeadModel, CTRLModel
from transformers import (
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
CTRLConfig,
CTRLForSequenceClassification,
CTRLLMHeadModel,
CTRLModel,
)
class CTRLModelTester:
@@ -57,6 +63,7 @@ class CTRLModelTester:
self.num_labels = 3
self.num_choices = 4
self.scope = None
self.pad_token_id = self.vocab_size - 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -94,6 +101,7 @@ class CTRLModelTester:
n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
)
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
@@ -149,11 +157,20 @@ class CTRLModelTester:
return config, inputs_dict
def create_and_check_ctrl_for_sequence_classification(self, config, input_ids, head_mask, token_type_ids, *args):
config.num_labels = self.num_labels
model = CTRLForSequenceClassification(config)
model.to(torch_device)
model.eval()
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
@require_torch
class CTRLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (CTRLModel, CTRLLMHeadModel) if is_torch_available() else ()
all_model_classes = (CTRLModel, CTRLLMHeadModel, CTRLForSequenceClassification) if is_torch_available() else ()
all_generative_model_classes = (CTRLLMHeadModel,) if is_torch_available() else ()
test_pruning = True
test_torchscript = False