Fix SeamlessM4Tv2ModelIntegrationTest (#27911)

change dtype of some integration tests
This commit is contained in:
Yoach Lacombe
2023-12-11 08:18:41 +00:00
committed by GitHub
parent e96c1de191
commit 5e620a92cf

View File

@@ -1014,8 +1014,9 @@ class SeamlessM4Tv2ModelIntegrationTest(unittest.TestCase):
) )
def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs): def factory_test_task(self, class1, class2, inputs, class1_kwargs, class2_kwargs):
model1 = class1.from_pretrained(self.repo_id).to(torch_device) # half-precision loading to limit GPU usage
model2 = class2.from_pretrained(self.repo_id).to(torch_device) model1 = class1.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device)
model2 = class2.from_pretrained(self.repo_id, torch_dtype=torch.float16).to(torch_device)
set_seed(0) set_seed(0)
output_1 = model1.generate(**inputs, **class1_kwargs) output_1 = model1.generate(**inputs, **class1_kwargs)