BartForSequenceClassification: fix num_labels, add test (#3110)
This commit is contained in:
@@ -171,7 +171,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
|
||||
vocab_size = 99
|
||||
|
||||
def test_lm_forward(self):
|
||||
def _get_config_and_data(self, output_past=False):
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[71, 82, 18, 33, 46, 91, 2],
|
||||
@@ -191,9 +191,8 @@ class BartHeadTests(unittest.TestCase):
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
batch_size = input_ids.shape[0]
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
config = BartConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=24,
|
||||
@@ -204,14 +203,25 @@ class BartHeadTests(unittest.TestCase):
|
||||
encoder_ffn_dim=32,
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
output_past=output_past,
|
||||
)
|
||||
return config, input_ids, batch_size
|
||||
|
||||
def test_sequence_classification_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
labels = _long_tensor([2] * batch_size).to(torch_device)
|
||||
model = BartForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids)
|
||||
logits = outputs[0]
|
||||
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
|
||||
logits = outputs[1]
|
||||
expected_shape = torch.Size((batch_size, config.num_labels))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
loss = outputs[0]
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
|
||||
def test_lm_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||
lm_model = BartForMaskedLM(config)
|
||||
lm_model.to(torch_device)
|
||||
loss, logits, enc_features = lm_model.forward(
|
||||
|
||||
Reference in New Issue
Block a user