small improvements (#10773)

This commit is contained in:
Patrick von Platen
2021-03-17 18:10:17 +03:00
committed by GitHub
parent d7e0d59bb7
commit 0486ccdd3d

View File

@@ -162,7 +162,7 @@ class Wav2Vec2ModelTester:
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
@@ -171,7 +171,7 @@ class Wav2Vec2ModelTester:
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
model.config.ctc_loss_reduction = "sum"
sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss