From 0486ccdd3db82ef4b1f636f638f5744a4dcc2d44 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 17 Mar 2021 18:10:17 +0300 Subject: [PATCH] small improvements (#10773) --- tests/test_modeling_wav2vec2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index ef269fd65b..434526c749 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -162,7 +162,7 @@ class Wav2Vec2ModelTester: model.eval() input_values = input_values[:3] - attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool) + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long) input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) @@ -171,7 +171,7 @@ class Wav2Vec2ModelTester: # pad input for i in range(len(input_lengths)): input_values[i, input_lengths[i] :] = 0.0 - attention_mask[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0 model.config.ctc_loss_reduction = "sum" sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss