Bart: new cache format (#35314)

* bart compile

* add mbart

* some more models touched by fix-copies

* more

* more models

* even more models

* fix copies

* fix tests

* fix copies

* fix

* biogpt accepts position ids now (breaking?)

* fix failing non-slow tests

* fix some tests

* should not be removed

* small update

* Update src/transformers/models/bart/modeling_bart.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* update for last `main`

* fix copies

* clone `update_causal_mask` from llama

* tmp

* fixup

* why? how?

* fix bart tests

* dont skip test

* address comments

* fix tests

* fix

* fixup and delete the file

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
Raushan Turganbay
2025-05-16 13:26:54 +02:00
committed by GitHub
parent 3ab47b6ce3
commit 01ad9f4b49
46 changed files with 3904 additions and 1995 deletions

View File

@@ -735,6 +735,7 @@ class ModelTesterMixin:
model = model_class(config)
model.to(torch_device)
model.eval()
print(model_class)
with torch.no_grad():
first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
@@ -4130,6 +4131,9 @@ class ModelTesterMixin:
if "position_ids" not in inspect.signature(model.forward).parameters:
self.skipTest("Model does not support position_ids")
if "position_ids" not in inspect.signature(model.forward).parameters:
continue # this model doesn't accept position ids as input
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
@@ -4268,7 +4272,16 @@ class ModelTesterMixin:
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
if config.is_encoder_decoder:
_ = fa2_model(
input_ids=dummy_input,
attention_mask=dummy_attention_mask,
decoder_input_ids=dummy_input.clone(),
decoder_attention_mask=dummy_attention_mask.clone(),
)
else:
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
with tempfile.TemporaryDirectory() as tmpdirname:
fa2_model.save_pretrained(tmpdirname)
model_from_pretrained = model_class.from_pretrained(tmpdirname)
@@ -4327,8 +4340,10 @@ class ModelTesterMixin:
set_config_for_less_flaky_test(config)
if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
model = model_class(config).to(device=torch_device, dtype=torch.float32)
model = model_class(config).to(device=torch_device, dtype=torch.float32).eval()
set_model_for_less_flaky_test(model)
if "position_ids" not in inspect.signature(model.forward).parameters:
continue # model doesn't accept position ids and probably has special way to model positions
if "position_ids" not in inspect.signature(model.forward).parameters:
continue # this model doesn't accept position ids as input