[gemma3] support sequence classification task (#39465)
* add seq clf class * fix docs and add in auto-map * skip tests * optional pixels
This commit is contained in:
committed by
GitHub
parent
34133d0a79
commit
e42681b48b
@@ -267,3 +267,8 @@ visualizer("<img>What is shown in this image?")
|
||||
|
||||
[[autodoc]] Gemma3ForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## Gemma3ForSequenceClassification
|
||||
|
||||
[[autodoc]] Gemma3ForSequenceClassification
|
||||
- forward
|
||||
|
||||
@@ -1135,6 +1135,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
("funnel", "FunnelForSequenceClassification"),
|
||||
("gemma", "GemmaForSequenceClassification"),
|
||||
("gemma2", "Gemma2ForSequenceClassification"),
|
||||
("gemma3", "Gemma3ForSequenceClassification"),
|
||||
("glm", "GlmForSequenceClassification"),
|
||||
("glm4", "Glm4ForSequenceClassification"),
|
||||
("gpt-sw3", "GPT2ForSequenceClassification"),
|
||||
|
||||
@@ -34,7 +34,7 @@ from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
@@ -1212,10 +1212,99 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
|
||||
return create_masks_for_generate(**mask_kwargs)
|
||||
|
||||
|
||||
class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = Gemma3Model(config)
|
||||
self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> SequenceClassifierOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
|
||||
transformer_outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
token_type_ids=token_type_ids,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = transformer_outputs.last_hidden_state
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
else:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
|
||||
if self.config.text_config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.text_config.pad_token_id is None:
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Gemma3PreTrainedModel",
|
||||
"Gemma3TextModel",
|
||||
"Gemma3ForCausalLM",
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Gemma3Model",
|
||||
"Gemma3ForSequenceClassification",
|
||||
]
|
||||
|
||||
@@ -27,7 +27,7 @@ from ...configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPast
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
@@ -1069,6 +1069,94 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
|
||||
return create_masks_for_generate(**mask_kwargs)
|
||||
|
||||
|
||||
class Gemma3ForSequenceClassification(Gemma3PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
self.model = Gemma3Model(config)
|
||||
self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> SequenceClassifierOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
|
||||
transformer_outputs = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
pixel_values=pixel_values,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
token_type_ids=token_type_ids,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = transformer_outputs.last_hidden_state
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
if input_ids is not None:
|
||||
batch_size = input_ids.shape[0]
|
||||
else:
|
||||
batch_size = inputs_embeds.shape[0]
|
||||
|
||||
if self.config.text_config.pad_token_id is None and batch_size != 1:
|
||||
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
||||
if self.config.text_config.pad_token_id is None:
|
||||
last_non_pad_token = -1
|
||||
elif input_ids is not None:
|
||||
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
|
||||
non_pad_mask = (input_ids != self.config.text_config.pad_token_id).to(logits.device, torch.int32)
|
||||
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
|
||||
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
|
||||
else:
|
||||
last_non_pad_token = -1
|
||||
logger.warning_once(
|
||||
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
||||
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
||||
)
|
||||
|
||||
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
|
||||
|
||||
return SequenceClassifierOutputWithPast(
|
||||
loss=loss,
|
||||
logits=pooled_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Gemma3Config",
|
||||
"Gemma3TextConfig",
|
||||
@@ -1077,4 +1165,5 @@ __all__ = [
|
||||
"Gemma3ForCausalLM",
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Gemma3Model",
|
||||
"Gemma3ForSequenceClassification",
|
||||
]
|
||||
|
||||
@@ -53,6 +53,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
Gemma3ForCausalLM,
|
||||
Gemma3ForConditionalGeneration,
|
||||
Gemma3ForSequenceClassification,
|
||||
Gemma3Model,
|
||||
Gemma3Processor,
|
||||
Gemma3TextModel,
|
||||
@@ -246,6 +247,7 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
(
|
||||
Gemma3Model,
|
||||
Gemma3ForConditionalGeneration,
|
||||
Gemma3ForSequenceClassification,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
@@ -348,6 +350,14 @@ class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unitte
|
||||
def test_initialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Loading nested configs with overwritten `kwargs` isn't supported yet, FIXME @raushan.")
|
||||
def test_load_with_mismatched_shapes(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Loading nested configs with overwritten `kwargs` isn't supported yet, FIXME @raushan.")
|
||||
def test_mismatched_shapes_have_properly_initialized_weights(self):
|
||||
pass
|
||||
|
||||
def test_automodelforcausallm(self):
|
||||
"""
|
||||
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
|
||||
|
||||
Reference in New Issue
Block a user