[Wav2Vec2, Hubert] Fix ctc loss test (#12458)
* fix_torch_device_generate_test * remove @ * fix test
This commit is contained in:
committed by
GitHub
parent
b655f16d4e
commit
27d348f2fe
@@ -176,12 +176,13 @@ class HubertModelTester:
|
||||
attention_mask[i, input_lengths[i] :] = 0
|
||||
|
||||
model.config.ctc_loss_reduction = "sum"
|
||||
sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
|
||||
sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
|
||||
|
||||
model.config.ctc_loss_reduction = "mean"
|
||||
mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss
|
||||
mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
|
||||
|
||||
self.parent.assertTrue(abs(labels.shape[0] * labels.shape[1] * mean_loss.item() - sum_loss.item()) < 1e-3)
|
||||
self.parent.assertTrue(isinstance(sum_loss, float))
|
||||
self.parent.assertTrue(isinstance(mean_loss, float))
|
||||
|
||||
def check_training(self, config, input_values, *args):
|
||||
config.ctc_zero_infinity = True
|
||||
|
||||
Reference in New Issue
Block a user