BartForSequenceClassification: fix num_labels, add test (#3110)
This commit is contained in:
@@ -1324,7 +1324,7 @@ class BartForSequenceClassification(PretrainedBartModel):
|
|||||||
# Prepend logits
|
# Prepend logits
|
||||||
outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here
|
outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here
|
||||||
if labels is not None: # prepend loss to output,
|
if labels is not None: # prepend loss to output,
|
||||||
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||||
outputs = (loss,) + outputs
|
outputs = (loss,) + outputs
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
|
|
||||||
vocab_size = 99
|
vocab_size = 99
|
||||||
|
|
||||||
def test_lm_forward(self):
|
def _get_config_and_data(self, output_past=False):
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.tensor(
|
||||||
[
|
[
|
||||||
[71, 82, 18, 33, 46, 91, 2],
|
[71, 82, 18, 33, 46, 91, 2],
|
||||||
@@ -191,9 +191,8 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=torch_device,
|
device=torch_device,
|
||||||
)
|
)
|
||||||
batch_size = input_ids.shape[0]
|
|
||||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
|
||||||
|
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
config = BartConfig(
|
config = BartConfig(
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
d_model=24,
|
d_model=24,
|
||||||
@@ -204,14 +203,25 @@ class BartHeadTests(unittest.TestCase):
|
|||||||
encoder_ffn_dim=32,
|
encoder_ffn_dim=32,
|
||||||
decoder_ffn_dim=32,
|
decoder_ffn_dim=32,
|
||||||
max_position_embeddings=48,
|
max_position_embeddings=48,
|
||||||
|
output_past=output_past,
|
||||||
)
|
)
|
||||||
|
return config, input_ids, batch_size
|
||||||
|
|
||||||
|
def test_sequence_classification_forward(self):
|
||||||
|
config, input_ids, batch_size = self._get_config_and_data()
|
||||||
|
labels = _long_tensor([2] * batch_size).to(torch_device)
|
||||||
model = BartForSequenceClassification(config)
|
model = BartForSequenceClassification(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids)
|
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
|
||||||
logits = outputs[0]
|
logits = outputs[1]
|
||||||
expected_shape = torch.Size((batch_size, config.num_labels))
|
expected_shape = torch.Size((batch_size, config.num_labels))
|
||||||
self.assertEqual(logits.shape, expected_shape)
|
self.assertEqual(logits.shape, expected_shape)
|
||||||
|
loss = outputs[0]
|
||||||
|
self.assertIsInstance(loss.item(), float)
|
||||||
|
|
||||||
|
def test_lm_forward(self):
|
||||||
|
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
|
||||||
|
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||||
lm_model = BartForMaskedLM(config)
|
lm_model = BartForMaskedLM(config)
|
||||||
lm_model.to(torch_device)
|
lm_model.to(torch_device)
|
||||||
loss, logits, enc_features = lm_model.forward(
|
loss, logits, enc_features = lm_model.forward(
|
||||||
|
|||||||
Reference in New Issue
Block a user