Add Cohere2 model (#35224)
This commit is contained in:
@@ -40,6 +40,11 @@ if is_torch_available():
|
||||
|
||||
# Copied from transformers.tests.models.llama.LlamaModelTester with Llama->Cohere
|
||||
class CohereModelTester:
|
||||
config_class = CohereConfig
|
||||
if is_torch_available():
|
||||
model_class = CohereModel
|
||||
for_causal_lm_class = CohereForCausalLM
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
@@ -51,7 +56,7 @@ class CohereModelTester:
|
||||
use_labels=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=37,
|
||||
hidden_act="gelu",
|
||||
@@ -115,7 +120,7 @@ class CohereModelTester:
|
||||
|
||||
# Ignore copy
|
||||
def get_config(self):
|
||||
return CohereConfig(
|
||||
return self.config_class(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
@@ -129,13 +134,12 @@ class CohereModelTester:
|
||||
is_decoder=False,
|
||||
initializer_range=self.initializer_range,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = CohereModel(config=config)
|
||||
model = self.model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask)
|
||||
@@ -155,7 +159,7 @@ class CohereModelTester:
|
||||
encoder_attention_mask,
|
||||
):
|
||||
config.add_cross_attention = True
|
||||
model = CohereModel(config)
|
||||
model = self.model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(
|
||||
@@ -184,7 +188,7 @@ class CohereModelTester:
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
model = CohereForCausalLM(config=config)
|
||||
model = self.for_causal_lm_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||
@@ -204,7 +208,7 @@ class CohereModelTester:
|
||||
):
|
||||
config.is_decoder = True
|
||||
config.add_cross_attention = True
|
||||
model = CohereForCausalLM(config=config)
|
||||
model = self.for_causal_lm_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
@@ -281,7 +285,7 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
)
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
fx_compatible = True
|
||||
fx_compatible = False
|
||||
|
||||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
|
||||
# This is because we are hitting edge cases with the causal_mask buffer
|
||||
|
||||
Reference in New Issue
Block a user