Gemma capping (#34282)

* softcapping

* soft cap before the mask

* style

* ...

* super nit

* update

* fixes

* update

* small issue with modular

* fix modular imports

* update

* fixup

* simplify a hell lot

* simplify cleaning imports

* finish fixing

* update our design

* nits

* use a deprecation cycle

* updates

* Fix modular (recursive deps need to always be computed after merges!)

* push

* fix

* update

* fix modular order

* make fix-copies

* updates

* update

* ?

* don't compile for now

* ?

* fix some stuff

* donc!

* fix copies

* update

* fixup

* ?

* fix two tests

* fix?

* for now, don't use head info

* eager when output attention and sdpa or flash as it's the simplest behaviour (for our tests as well :))

* fix-copies

* revert sdpa check

* Apply suggestions from code review

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>

* rebase, fix-copies and push

* add a slow integration test

* update the test

* fix left padding issue

* fix test

* remove duplicate scaling

* quality

* add a small test and make sure it works

* 2b

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
Arthur
2024-11-19 13:52:38 +01:00
committed by GitHub
parent 54739a320e
commit 4bff54f921
8 changed files with 431 additions and 541 deletions

View File

@@ -199,19 +199,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
def test_sdpa_equivalence(self):
pass
def test_eager_attention_loaded_by_default(self):
    """Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default.

    Builds a `Gemma2Model` from the common tester config and checks the attention
    implementation recorded on the model config, first without any override and
    then after explicitly requesting SDPA on the config.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()

    # Usually we enable SDPA by default, but not for Gemma2
    model = Gemma2Model(config)
    # assertEqual (instead of assertTrue(a == b)) reports both values when the check fails
    self.assertEqual(model.config._attn_implementation, "eager")

    # We can still force SDPA
    config._attn_implementation = "sdpa"
    model = Gemma2Model(config)
    self.assertEqual(model.config._attn_implementation, "sdpa")
@slow
@require_torch_gpu
@@ -277,9 +264,30 @@ class Gemma2IntegrationTest(unittest.TestCase):
"Hi today I'm going to be talking about the history of the United States. The United States of America",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
torch_device
)
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)
self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
@require_read_token
def test_model_2b_pipeline_bf16_flex_attention(self):
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
model_id = "google/gemma-2-2b"
# EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1960s and I am trying to find out what the average",
"Hi today I'm going to be talking about the 10 best anime of all time.\n\n1",
]
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
@@ -365,3 +373,23 @@ class Gemma2IntegrationTest(unittest.TestCase):
)
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
@require_read_token
def test_model_9b_bf16_flex_attention(self):
    """Compare non-sampled generation of the 9B checkpoint (bfloat16, `flex_attention`) to fixed texts.

    The batch is decoded with `skip_special_tokens=False`, so the expected strings
    contain the `<bos>` marker and, for the second prompt, the leading `<pad>` tokens.
    """
    checkpoint = "google/gemma-2-9b"
    expected_generations = [
        "<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
        "<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
    ]

    # Load the model with the attention backend under test, then move it to the test device.
    load_kwargs = {
        "low_cpu_mem_usage": True,
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "flex_attention",
    }
    model = AutoModelForCausalLM.from_pretrained(checkpoint, **load_kwargs)
    model = model.to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Tokenize the prompts as one padded batch on the same device as the model.
    encoded = tokenizer(self.input_text, return_tensors="pt", padding=True)
    encoded = encoded.to(torch_device)

    generated_ids = model.generate(**encoded, max_new_tokens=20, do_sample=False)
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)

    self.assertEqual(decoded, expected_generations)