Purge unused ModelTester code (#37085)

* Purge correctly this time

* Remove more methods from recent PRs

* make fixup
This commit is contained in:
Matt
2025-04-03 17:48:35 +01:00
committed by GitHub
parent 1b29409d89
commit 2d46a08b63
72 changed files with 3 additions and 4286 deletions

View File

@@ -246,32 +246,6 @@ class Wav2Vec2BertModelTester:
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
)
def create_and_check_batch_inference(self, config, input_features, *args):
# test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227
model = Wav2Vec2BertModel(config=config)
model.to(torch_device)
model.eval()
input_features = input_features[:3]
attention_mask = torch.ones(input_features.shape, device=torch_device, dtype=torch.bool)
input_lengths = [input_features.shape[-1] // i for i in [4, 2, 1]]
# pad input
for i in range(len(input_lengths)):
input_features[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0.0
batch_outputs = model(input_features, attention_mask=attention_mask).last_hidden_state
for i in range(input_features.shape[0]):
input_slice = input_features[i : i + 1, : input_lengths[i]]
output = model(input_slice).last_hidden_state
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
def check_ctc_loss(self, config, input_features, *args):
model = Wav2Vec2BertForCTC(config=config)
model.to(torch_device)