Add tests for batching support (#29297)
* add tests for batching support * Update src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/test_modeling_common.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/test_modeling_common.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/test_modeling_common.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * fixes and comments * use cosine distance for conv models * skip mra model testing * Update tests/models/vilt/test_modeling_vilt.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * finzalize and make style * check model type by input names * Update tests/models/vilt/test_modeling_vilt.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fixed batch size for all testers * Revert "fixed batch size for all testers" This reverts commit 525f3a0a058f069fbda00352cf202b728d40df99. * add batch_size for all testers * dict from model output * do not skip layoutlm * bring back some code from git revert * Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/test_modeling_common.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * clean-up * where did minus go in tolerance * make whisper happy * deal with consequences of losing minus * deal with consequences of losing minus * maskformer needs its own test for happiness * fix more models * tag flaky CV models from Amy's approval * make codestyle --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
11163fff58
commit
8e64ba2890
@@ -66,13 +66,13 @@ class UnivNetModelTester:
|
||||
|
||||
def prepare_noise_sequence(self):
|
||||
generator = torch.manual_seed(self.seed)
|
||||
noise_shape = (self.seq_length, self.in_channels)
|
||||
noise_shape = (self.batch_size, self.seq_length, self.in_channels)
|
||||
# Create noise on CPU for reproducibility
|
||||
noise_sequence = torch.randn(noise_shape, generator=generator, dtype=torch.float)
|
||||
return noise_sequence
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
spectrogram = floats_tensor([self.seq_length, self.num_mel_bins], scale=1.0)
|
||||
spectrogram = floats_tensor([self.batch_size, self.seq_length, self.num_mel_bins], scale=1.0)
|
||||
noise_sequence = self.prepare_noise_sequence()
|
||||
noise_sequence = noise_sequence.to(spectrogram.device)
|
||||
config = self.get_config()
|
||||
@@ -89,7 +89,7 @@ class UnivNetModelTester:
|
||||
def create_and_check_model(self, config, spectrogram, noise_sequence):
|
||||
model = UnivNetModel(config=config).to(torch_device).eval()
|
||||
result = model(spectrogram, noise_sequence)[0]
|
||||
self.parent.assertEqual(result.shape, (1, self.seq_length * 256))
|
||||
self.parent.assertEqual(result.shape, (self.batch_size, self.seq_length * 256))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, spectrogram, noise_sequence = self.prepare_config_and_inputs()
|
||||
@@ -182,8 +182,8 @@ class UnivNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
batched_spectrogram = inputs["input_features"].unsqueeze(0).repeat(2, 1, 1)
|
||||
batched_noise_sequence = inputs["noise_sequence"].unsqueeze(0).repeat(2, 1, 1)
|
||||
batched_spectrogram = inputs["input_features"]
|
||||
batched_noise_sequence = inputs["noise_sequence"]
|
||||
with torch.no_grad():
|
||||
batched_outputs = model(
|
||||
batched_spectrogram.to(torch_device),
|
||||
@@ -205,37 +205,11 @@ class UnivNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(inputs["input_features"].to(torch_device), inputs["noise_sequence"].to(torch_device))[
|
||||
0
|
||||
]
|
||||
outputs = model(
|
||||
inputs["input_features"][:1].to(torch_device), inputs["noise_sequence"][:1].to(torch_device)
|
||||
)[0]
|
||||
self.assertTrue(outputs.shape[0] == 1, msg="Unbatched input should create batched output with bsz = 1")
|
||||
|
||||
def test_unbatched_batched_outputs_consistency(self):
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
unbatched_spectrogram = inputs["input_features"].detach().clone()
|
||||
unbatched_noise_sequence = inputs["noise_sequence"].detach().clone()
|
||||
batched_spectrogram = inputs["input_features"].unsqueeze(0)
|
||||
batched_noise_sequence = inputs["noise_sequence"].unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
unbatched_outputs = model(
|
||||
unbatched_spectrogram.to(torch_device),
|
||||
unbatched_noise_sequence.to(torch_device),
|
||||
)[0]
|
||||
|
||||
batched_outputs = model(
|
||||
batched_spectrogram.to(torch_device),
|
||||
batched_noise_sequence.to(torch_device),
|
||||
)[0]
|
||||
|
||||
torch.testing.assert_close(unbatched_outputs, batched_outputs)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user