Add Cohere2 model (#35224)

This commit is contained in:
alexrs-cohere
2024-12-13 09:35:50 +01:00
committed by GitHub
parent e4e404fdd0
commit 64478c7631
19 changed files with 2508 additions and 9 deletions

View File

@@ -40,6 +40,11 @@ if is_torch_available():
# Copied from transformers.tests.models.llama.LlamaModelTester with Llama->Cohere
class CohereModelTester:
config_class = CohereConfig
if is_torch_available():
model_class = CohereModel
for_causal_lm_class = CohereForCausalLM
def __init__(
self,
parent,
@@ -51,7 +56,7 @@ class CohereModelTester:
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_hidden_layers=4,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
@@ -115,7 +120,7 @@ class CohereModelTester:
# Ignore copy
def get_config(self):
return CohereConfig(
return self.config_class(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
@@ -129,13 +134,12 @@ class CohereModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
eos_token_id=self.pad_token_id,
)
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = CohereModel(config=config)
model = self.model_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
@@ -155,7 +159,7 @@ class CohereModelTester:
encoder_attention_mask,
):
config.add_cross_attention = True
model = CohereModel(config)
model = self.model_class(config)
model.to(torch_device)
model.eval()
result = model(
@@ -184,7 +188,7 @@ class CohereModelTester:
encoder_hidden_states,
encoder_attention_mask,
):
model = CohereForCausalLM(config=config)
model = self.for_causal_lm_class(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
@@ -204,7 +208,7 @@ class CohereModelTester:
):
config.is_decoder = True
config.add_cross_attention = True
model = CohereForCausalLM(config=config)
model = self.for_causal_lm_class(config=config)
model.to(torch_device)
model.eval()
@@ -281,7 +285,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
)
test_headmasking = False
test_pruning = False
fx_compatible = True
fx_compatible = False
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer