[refactor] set attention implementation (#38974)
* update * fix some tests * init from config, changes it in-place, add deepcopy in tests * fix modernbert * don't delete thsi config attr * update * style and copies * skip tests in generation * fix style * accidentally removed flash-attn-3, revert * docs * forgot about flags set to False * fix copies * address a few comments * fix copies * custom code BC
This commit is contained in:
committed by
GitHub
parent
6017f5e8ed
commit
8d6259b0b8
@@ -699,7 +699,7 @@ class ModelTesterMixin:
|
||||
def test_from_pretrained_no_checkpoint(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
state_dict = model.state_dict()
|
||||
|
||||
new_model = model_class.from_pretrained(
|
||||
@@ -714,7 +714,7 @@ class ModelTesterMixin:
|
||||
if model_class._keep_in_fp32_modules is None:
|
||||
self.skipTest(reason="Model class has no _keep_in_fp32_modules attribute defined")
|
||||
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
@@ -730,7 +730,7 @@ class ModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
_keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
|
||||
if _keys_to_ignore_on_save is None:
|
||||
continue
|
||||
@@ -766,7 +766,7 @@ class ModelTesterMixin:
|
||||
continue
|
||||
|
||||
config.gradient_checkpointing = True
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
self.assertTrue(model.is_gradient_checkpointing)
|
||||
|
||||
def test_gradient_checkpointing_enable_disable(self):
|
||||
@@ -777,7 +777,7 @@ class ModelTesterMixin:
|
||||
continue
|
||||
|
||||
# at init model should have gradient checkpointing disabled
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
self.assertFalse(model.is_gradient_checkpointing)
|
||||
|
||||
# check enable works
|
||||
@@ -810,7 +810,7 @@ class ModelTesterMixin:
|
||||
continue
|
||||
|
||||
# at init model should have gradient checkpointing disabled
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
self.assertFalse(model.is_gradient_checkpointing)
|
||||
|
||||
# check enable works
|
||||
@@ -871,7 +871,7 @@ class ModelTesterMixin:
|
||||
|
||||
# First, initialize the model from config -> this ensure everything is correctly initialized, even if
|
||||
# _init_weights() does not take all weights into account correctly
|
||||
model_from_config = model_class(config)
|
||||
model_from_config = model_class(copy.deepcopy(config))
|
||||
# Here, passing an empty state dict will force all weights to be moved from meta to cpu, then be initialized
|
||||
# by _init_weights()
|
||||
model_from_pretrained = model_class.from_pretrained(None, config=config, state_dict={})
|
||||
@@ -944,7 +944,7 @@ class ModelTesterMixin:
|
||||
base_class_copy._init_weights = _mock_init_weights
|
||||
base_class_copy.init_weights = _mock_all_init_weights
|
||||
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
state_dict = model.state_dict()
|
||||
|
||||
def check_equal(loaded):
|
||||
@@ -969,7 +969,7 @@ class ModelTesterMixin:
|
||||
|
||||
configs_no_init = _config_zero_init(config)
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=configs_no_init)
|
||||
model = model_class(config=copy.deepcopy(configs_no_init))
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
data = torch.flatten(param.data)
|
||||
@@ -1000,7 +1000,7 @@ class ModelTesterMixin:
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
@@ -1075,7 +1075,7 @@ class ModelTesterMixin:
|
||||
if hasattr(self.model_tester, "prepare_config_and_inputs_for_model_class"):
|
||||
config, batched_input = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
|
||||
batched_input_prepared = self._prepare_for_class(batched_input, model_class)
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
model = model_class(copy.deepcopy(config)).to(torch_device).eval()
|
||||
set_model_for_less_flaky_test(model)
|
||||
|
||||
batch_size = self.model_tester.batch_size
|
||||
@@ -1932,7 +1932,7 @@ class ModelTesterMixin:
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
@@ -2061,16 +2061,15 @@ class ModelTesterMixin:
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
torch.manual_seed(0)
|
||||
config = copy.deepcopy(original_config)
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(original_config))
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
hidden_states_no_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
config.chunk_size_feed_forward = 1
|
||||
model = model_class(config)
|
||||
original_config.chunk_size_feed_forward = 1
|
||||
model = model_class(copy.deepcopy(original_config))
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
@@ -2445,7 +2444,7 @@ class ModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
self.assertIsInstance(model.get_input_embeddings(), nn.Embedding)
|
||||
|
||||
new_input_embedding_layer = nn.Embedding(10, 10)
|
||||
@@ -2505,7 +2504,7 @@ class ModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config.torchscript = True
|
||||
model_not_tied = model_class(config)
|
||||
model_not_tied = model_class(copy.deepcopy(config))
|
||||
if model_not_tied.get_output_embeddings() is None:
|
||||
continue
|
||||
|
||||
@@ -2582,7 +2581,7 @@ class ModelTesterMixin:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.get_text_config().tie_word_embeddings = True
|
||||
for model_class in self.all_model_classes:
|
||||
model_tied = model_class(config)
|
||||
model_tied = model_class(copy.deepcopy(config))
|
||||
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in model_tied.state_dict().items():
|
||||
@@ -2707,7 +2706,7 @@ class ModelTesterMixin:
|
||||
recursive_check(tuple_output, dict_output)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
@@ -3033,7 +3032,7 @@ class ModelTesterMixin:
|
||||
continue
|
||||
|
||||
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config).eval()
|
||||
model = model_class(copy.deepcopy(config)).eval()
|
||||
model = model.to(torch_device)
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict_class)
|
||||
@@ -3077,7 +3076,7 @@ class ModelTesterMixin:
|
||||
continue
|
||||
|
||||
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config).eval()
|
||||
model = model_class(copy.deepcopy(config)).eval()
|
||||
model = model.to(torch_device)
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict_class)
|
||||
@@ -3115,7 +3114,7 @@ class ModelTesterMixin:
|
||||
continue
|
||||
|
||||
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config).eval()
|
||||
model = model_class(copy.deepcopy(config)).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
@@ -3470,7 +3469,7 @@ class ModelTesterMixin:
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model = model_class(copy.deepcopy(config))
|
||||
num_params = model.num_parameters()
|
||||
assert num_params < 1000000, (
|
||||
f"{model_class} is too big for the common tests ({num_params})! It should have 1M max."
|
||||
@@ -3628,7 +3627,7 @@ class ModelTesterMixin:
|
||||
|
||||
# set eager as it will be the one supported in all models
|
||||
# we just need to test if passing 'attn_implementation' as a dict fails or not
|
||||
attn_implementation_per_subconfig = {}
|
||||
attn_implementation_per_subconfig = {"": "eager"}
|
||||
for key in config.sub_configs.keys():
|
||||
attn_implementation_per_subconfig[key] = "eager"
|
||||
|
||||
@@ -4717,7 +4716,7 @@ class ModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
# If it does not raise here, the test passes
|
||||
with torch.device("meta"):
|
||||
_ = model_class(config)
|
||||
_ = model_class(copy.deepcopy(config))
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_can_load_with_device_context_manager(self):
|
||||
|
||||
Reference in New Issue
Block a user