Gemma capping (#34282)
* softcapping * soft cap before the mask * style * ... * super nit * update * fixes * update * small issue with modular * fix modular imports * update * fixup * simplify a hell lot * simplify cleaning imports * finish fixing * update our design * nits * use a deprecation cycle * updates * Fix modular (recursive deps need to always be computed after merges!) * push * fix * update * fix modular order * make fix-copies * updates * update * ? * don't compile for now * ? * fix some stuff * donc! * fix copies * update * fixup * ? * fix two tests * fix? * for now, don't use head info * eager when output attention and sdpa or flash as it's the simplest behaviour (for our tests as well :)) * fix-copies * revert sdpa check * Apply suggestions from code review Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co> * rebase, fix-copies and push * add a slow integration test * update the test * fix left padding issue * fix test * remove duplicate scaling * quality * add a small test and make sure it works * 2b --------- Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -199,19 +199,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
||||
def test_sdpa_equivalence(self):
|
||||
pass
|
||||
|
||||
def test_eager_attention_loaded_by_default(self):
|
||||
"""Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default."""
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Usually we enable SDPA by default, but not for Gemma2
|
||||
model = Gemma2Model(config)
|
||||
self.assertTrue(model.config._attn_implementation == "eager")
|
||||
|
||||
# We can still force SDPA
|
||||
config._attn_implementation = "sdpa"
|
||||
model = Gemma2Model(config)
|
||||
self.assertTrue(model.config._attn_implementation == "sdpa")
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@@ -277,9 +264,30 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
"Hi today I'm going to be talking about the history of the United States. The United States of America",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
|
||||
).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
||||
|
||||
output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)
|
||||
|
||||
self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
|
||||
self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])
|
||||
|
||||
@require_read_token
|
||||
def test_model_2b_pipeline_bf16_flex_attention(self):
|
||||
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
|
||||
model_id = "google/gemma-2-2b"
|
||||
# EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1960s and I am trying to find out what the average",
|
||||
"Hi today I'm going to be talking about the 10 best anime of all time.\n\n1",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
|
||||
).to(torch_device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
||||
|
||||
@@ -365,3 +373,23 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
)
|
||||
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
|
||||
|
||||
@require_read_token
|
||||
def test_model_9b_bf16_flex_attention(self):
|
||||
model_id = "google/gemma-2-9b"
|
||||
EXPECTED_TEXTS = [
|
||||
"<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
|
||||
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
|
||||
).to(torch_device)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
Reference in New Issue
Block a user