Enable multi-device for some models (#30207)
* feat: multidevice for resnet * feat: yes! resnet * fix: compare all elements in tuple * feat: support for regnet * feat: support for convnextv2 * feat: support for bit * feat: support for cvt * feat: add support for focalnet * feat: support for yolos * feat: support for glpn * feat: support for imagegpt * feat: support for levit * feat: support for mgp_str * feat: support for mobilnet_v1 * feat: support for mobilnet_v2 * feat: support for mobilevit * feat: support for mobilevitv2 * feat: support for poolformer * fix: copies * fix: code quality check * update: upstream changes from main * fix: consistency check * feat: support for sam * feat: support for switchformer * feat: support for swin * feat: support for swinv2 * feat: support for timesformer * feat: suport for trocr * feat: support for upernet * fix: check copies * update: rerun CI * update: rerun again, maybe * update: one more rerun --------- Co-authored-by: Jacky Lee <jackylee328@gmail.com>
This commit is contained in:
@@ -2907,7 +2907,10 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
|
||||
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
|
||||
else:
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@@ -2939,7 +2942,10 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
|
||||
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
|
||||
else:
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@@ -2975,7 +2981,10 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
|
||||
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
|
||||
else:
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
@require_accelerate
|
||||
@mark.accelerate_tests
|
||||
@@ -3011,7 +3020,10 @@ class ModelTesterMixin:
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict_class)
|
||||
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
|
||||
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
|
||||
else:
|
||||
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
|
||||
|
||||
def test_problem_types(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user