Bart: new cache format (#35314)
* bart compile
* add mbart
* some more models touched by fix-copies
* more
* more models
* even more models
* fix copies
* fix tests
* fix copies
* fix
* biogpt accepts position ids now (breaking?)
* fix failing non-slow tests
* fix some tests
* should not be removed
* small update
* Update src/transformers/models/bart/modeling_bart.py

  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* update for last `main`
* fix copies
* clone `update_causal_mask` from llama
* tmp
* fixup
* why? how?
* fix bart tests
* dont skip test
* address comments
* fix tests
* fix
* fixup and delete the file

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
3ab47b6ce3
commit
01ad9f4b49
@@ -735,6 +735,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
print(model_class)
|
||||
with torch.no_grad():
|
||||
first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||
@@ -4130,6 +4131,9 @@ class ModelTesterMixin:
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
self.skipTest("Model does not support position_ids")
|
||||
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
continue # this model doesn't accept position ids as input
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
@@ -4268,7 +4272,16 @@ class ModelTesterMixin:
|
||||
dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
|
||||
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
|
||||
|
||||
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
||||
if config.is_encoder_decoder:
|
||||
_ = fa2_model(
|
||||
input_ids=dummy_input,
|
||||
attention_mask=dummy_attention_mask,
|
||||
decoder_input_ids=dummy_input.clone(),
|
||||
decoder_attention_mask=dummy_attention_mask.clone(),
|
||||
)
|
||||
else:
|
||||
_ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fa2_model.save_pretrained(tmpdirname)
|
||||
model_from_pretrained = model_class.from_pretrained(tmpdirname)
|
||||
@@ -4327,8 +4340,10 @@ class ModelTesterMixin:
|
||||
set_config_for_less_flaky_test(config)
|
||||
if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0:
|
||||
self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test")
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32)
|
||||
model = model_class(config).to(device=torch_device, dtype=torch.float32).eval()
|
||||
set_model_for_less_flaky_test(model)
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
continue # model doesn't accept position ids and probably has special way to model positions
|
||||
|
||||
if "position_ids" not in inspect.signature(model.forward).parameters:
|
||||
continue # this model doesn't accept position ids as input
|
||||
|
||||
Reference in New Issue
Block a user