Added Sequence Classification class in GPTNeo (#11906)

* seq classification changes

* fix tests
This commit is contained in:
Bhadresh Savani
2021-05-28 15:57:02 +05:30
committed by GitHub
parent 80d712fac6
commit e1205e478a
9 changed files with 159 additions and 4 deletions

View File

@@ -361,7 +361,6 @@ class GPT2ModelTester:
model = GPT2ForSequenceClassification(config)
model.to(torch_device)
model.eval()
print(config.num_labels, sequence_labels.size())
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))