Add gemma 2 (#31659)

* inital commit

* Add doc

* protect?

* fixup stuffs

* update tests

* fix build documentation

* mmmmmmm config attributes

* style

* nit

* uodate

* nit

* Fix docs

* protect some stuff

---------

Co-authored-by: Lysandre <lysandre@huggingface.co>
This commit is contained in:
Arthur
2024-06-27 17:36:19 +02:00
committed by GitHub
parent 4aa17d0069
commit 0cf60f13ab
24 changed files with 3057 additions and 69 deletions

View File

@@ -47,11 +47,18 @@ if is_torch_available():
GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel,
GemmaTokenizer,
)
@require_torch
class GemmaModelTester:
config_class = GemmaConfig
if is_torch_available():
model_class = GemmaModel
for_causal_lm_class = GemmaForCausalLM
for_sequence_class = GemmaForSequenceClassification
for_token_class = GemmaForTokenClassification
def __init__(
self,
parent,
@@ -129,9 +136,8 @@ class GemmaModelTester:
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
# Ignore copy
def get_config(self):
return GemmaConfig(
return self.config_class(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
@@ -149,18 +155,16 @@ class GemmaModelTester:
head_dim=self.head_dim,
)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Gemma
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = GemmaModel(config=config)
model = self.model_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Gemma
def create_and_check_model_as_decoder(
self,
config,
@@ -174,7 +178,7 @@ class GemmaModelTester:
encoder_attention_mask,
):
config.add_cross_attention = True
model = GemmaModel(config)
model = self.model_class(config)
model.to(torch_device)
model.eval()
result = model(
@@ -191,7 +195,6 @@ class GemmaModelTester:
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Gemma
def create_and_check_for_causal_lm(
self,
config,
@@ -204,13 +207,12 @@ class GemmaModelTester:
encoder_hidden_states,
encoder_attention_mask,
):
model = GemmaForCausalLM(config=config)
model = self.for_causal_lm_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Gemma
def create_and_check_decoder_model_past_large_inputs(
self,
config,
@@ -225,7 +227,7 @@ class GemmaModelTester:
):
config.is_decoder = True
config.add_cross_attention = True
model = GemmaForCausalLM(config=config)
model = self.for_causal_lm_class(config=config)
model.to(torch_device)
model.eval()
@@ -348,7 +350,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = GemmaForSequenceClassification(config)
model = self.model_tester.for_sequence_class(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
@@ -361,7 +363,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size)
model = GemmaForSequenceClassification(config)
model = self.model_tester.for_sequence_class(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
@@ -376,20 +378,19 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
sequence_labels = ids_tensor(
[self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size
).to(torch.float)
model = GemmaForSequenceClassification(config)
model = self.model_tester.for_sequence_class(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_token_classification_model with Llama->Gemma,llama->Gemma
def test_Gemma_token_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
token_labels = ids_tensor([self.model_tester.batch_size, self.model_tester.seq_length], config.num_labels)
model = GemmaForTokenClassification(config=config)
model = self.model_tester.for_token_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=token_labels)
@@ -539,47 +540,9 @@ class GemmaIntegrationTest(unittest.TestCase):
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
@require_read_token
def test_model_2b_fp32(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_read_token
def test_model_2b_fp16(self):
model_id = "google/gemma-2b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
]
model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
torch_device
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_read_token
def test_model_2b_fp16_static_cache(self):
model_id = "google/gemma-2b"
model_id = "google/gemma-2-9b"
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
@@ -903,7 +866,7 @@ class GemmaIntegrationTest(unittest.TestCase):
}
prompts = ["Hello I am doing", "Hi today"]
tokenizer = GemmaTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)