Adding OPTForSeqClassification class (#18123)

* Adding OPTForSeqClassification class

* Fix import issues

* Add documentation for optforseqclassification

* Remove checkout

* fix failing tests

* fix typo

* Fix code formatting

* Incorporating the PR feedbacks

* Incorporate PR Feedbacks

* Fix failing test and add new test for multi label setup

* Fix formatting issue

* Fix failing tests

* Fix formatting issues

* Fix failing tests

* Fix failing tests

* Fix failing tests

* Fix failing tests

* PR feedback
This commit is contained in:
Raghavan
2022-07-20 13:44:21 +05:30
committed by GitHub
parent 0ed4d0dfb6
commit dcec4c4387
7 changed files with 195 additions and 8 deletions

View File

@@ -32,7 +32,7 @@ from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
import torch
from transformers import GPT2Tokenizer, OPTForCausalLM, OPTModel
from transformers import GPT2Tokenizer, OPTForCausalLM, OPTForSequenceClassification, OPTModel
def prepare_opt_inputs_dict(
@@ -74,7 +74,9 @@ class OPTModelTester:
pad_token_id=1,
bos_token_id=0,
embed_dim=16,
num_labels=3,
word_embed_proj_dim=16,
type_sequence_label_size=2,
):
self.parent = parent
self.batch_size = batch_size
@@ -94,11 +96,12 @@ class OPTModelTester:
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.embed_dim = embed_dim
self.num_labels = num_labels
self.type_sequence_label_size = type_sequence_label_size
self.word_embed_proj_dim = word_embed_proj_dim
self.is_encoder_decoder = False
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
3,
)
@@ -175,7 +178,7 @@ class OPTModelTester:
@require_torch
class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (OPTModel, OPTForCausalLM) if is_torch_available() else ()
all_model_classes = (OPTModel, OPTForCausalLM, OPTForSequenceClassification) if is_torch_available() else ()
all_generative_model_classes = (OPTForCausalLM,) if is_torch_available() else ()
is_encoder_decoder = False
fx_compatible = True
@@ -242,6 +245,33 @@ class OPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
def test_opt_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = OPTForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def test_opt_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
config.num_labels = 3
config.problem_type = "multi_label_classification"
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor(
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
).to(torch.float)
model = OPTForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""