@@ -124,7 +124,6 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@unittest.skip("Un-skip once https://github.com/mosaicml/llm-foundry/issues/703 is resolved")
|
|
||||||
def test_get_keys_to_not_convert_trust_remote_code(self):
|
def test_get_keys_to_not_convert_trust_remote_code(self):
|
||||||
r"""
|
r"""
|
||||||
Test the `get_keys_to_not_convert` function with `trust_remote_code` models.
|
Test the `get_keys_to_not_convert` function with `trust_remote_code` models.
|
||||||
@@ -135,11 +134,11 @@ class MixedInt8Test(BaseMixedInt8Test):
|
|||||||
|
|
||||||
model_id = "mosaicml/mpt-7b"
|
model_id = "mosaicml/mpt-7b"
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model_id, trust_remote_code=True, revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
|
model_id, trust_remote_code=True, revision="ada218f9a93b5f1c6dce48a4cc9ff01fcba431e7"
|
||||||
)
|
)
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = AutoModelForCausalLM.from_config(
|
model = AutoModelForCausalLM.from_config(
|
||||||
config, trust_remote_code=True, code_revision="72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7"
|
config, trust_remote_code=True, code_revision="ada218f9a93b5f1c6dce48a4cc9ff01fcba431e7"
|
||||||
)
|
)
|
||||||
self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])
|
self.assertEqual(get_keys_to_not_convert(model), ["transformer.wte"])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user