Add support for ModernBertForMultipleChoice (#39232)
* implement ModernBertForMultipleChoice
* fixup, style, repo consistency
* generate modeling_modernbert
* add tests + docs
* fix test
This commit is contained in:
@@ -115,6 +115,11 @@ echo -e "Plants create [MASK] through a process known as photosynthesis." | tran
|
|||||||
[[autodoc]] ModernBertForTokenClassification
|
[[autodoc]] ModernBertForTokenClassification
|
||||||
- forward
|
- forward
|
||||||
|
|
||||||
|
## ModernBertForMultipleChoice
|
||||||
|
|
||||||
|
[[autodoc]] ModernBertForMultipleChoice
|
||||||
|
- forward
|
||||||
|
|
||||||
## ModernBertForQuestionAnswering
|
## ModernBertForQuestionAnswering
|
||||||
|
|
||||||
[[autodoc]] ModernBertForQuestionAnswering
|
[[autodoc]] ModernBertForQuestionAnswering
|
||||||
|
|||||||
@@ -1469,6 +1469,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
|
|||||||
("mega", "MegaForMultipleChoice"),
|
("mega", "MegaForMultipleChoice"),
|
||||||
("megatron-bert", "MegatronBertForMultipleChoice"),
|
("megatron-bert", "MegatronBertForMultipleChoice"),
|
||||||
("mobilebert", "MobileBertForMultipleChoice"),
|
("mobilebert", "MobileBertForMultipleChoice"),
|
||||||
|
("modernbert", "ModernBertForMultipleChoice"),
|
||||||
("mpnet", "MPNetForMultipleChoice"),
|
("mpnet", "MPNetForMultipleChoice"),
|
||||||
("mra", "MraForMultipleChoice"),
|
("mra", "MraForMultipleChoice"),
|
||||||
("nezha", "NezhaForMultipleChoice"),
|
("nezha", "NezhaForMultipleChoice"),
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from ...modeling_layers import GradientCheckpointingLayer
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
MaskedLMOutput,
|
MaskedLMOutput,
|
||||||
|
MultipleChoiceModelOutput,
|
||||||
QuestionAnsweringModelOutput,
|
QuestionAnsweringModelOutput,
|
||||||
SequenceClassifierOutput,
|
SequenceClassifierOutput,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
@@ -605,7 +606,12 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
init_weight(module.decoder, stds["out"])
|
init_weight(module.decoder, stds["out"])
|
||||||
elif isinstance(
|
elif isinstance(
|
||||||
module,
|
module,
|
||||||
(ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
|
(
|
||||||
|
ModernBertForSequenceClassification,
|
||||||
|
ModernBertForMultipleChoice,
|
||||||
|
ModernBertForTokenClassification,
|
||||||
|
ModernBertForQuestionAnswering,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
init_weight(module.classifier, stds["final_out"])
|
init_weight(module.classifier, stds["final_out"])
|
||||||
elif isinstance(module, nn.LayerNorm):
|
elif isinstance(module, nn.LayerNorm):
|
||||||
@@ -1393,6 +1399,123 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring(
    custom_intro="""
    The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
    """
)
class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.config = config

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        # A single output unit: each flattened (example, choice) row gets one score, and
        # `forward` regroups those scores into rows of `num_choices` logits.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Fail early with a clear message instead of an AttributeError on `None.shape` below.
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You must specify either input_ids or inputs_embeds")

        # The choice axis is dimension 1 of whichever input was provided.
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice axis into the leading axis so the backbone sees one sequence per row.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        self._maybe_set_compile()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        if self.config.classifier_pooling == "cls":
            # Keep only the hidden state at position 0 of every flattened row.
            last_hidden_state = last_hidden_state[:, 0]
        elif self.config.classifier_pooling == "mean":
            if attention_mask is None:
                # Without a mask every position counts, so the masked average reduces to a
                # plain mean over the sequence axis (previously this path raised AttributeError).
                last_hidden_state = last_hidden_state.mean(dim=1)
            else:
                # Average over the positions the attention mask marks as real tokens.
                last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(
                    dim=1
                ) / attention_mask.sum(dim=1, keepdim=True)

        pooled_output = self.head(last_hidden_state)
        pooled_output = self.drop(pooled_output)
        logits = self.classifier(pooled_output)

        # Undo the flattening: one row per example, one column per choice.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ModernBertModel",
|
"ModernBertModel",
|
||||||
"ModernBertPreTrainedModel",
|
"ModernBertPreTrainedModel",
|
||||||
@@ -1400,4 +1523,5 @@ __all__ = [
|
|||||||
"ModernBertForSequenceClassification",
|
"ModernBertForSequenceClassification",
|
||||||
"ModernBertForTokenClassification",
|
"ModernBertForTokenClassification",
|
||||||
"ModernBertForQuestionAnswering",
|
"ModernBertForQuestionAnswering",
|
||||||
|
"ModernBertForMultipleChoice",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from ...modeling_layers import GradientCheckpointingLayer
|
|||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
MaskedLMOutput,
|
MaskedLMOutput,
|
||||||
|
MultipleChoiceModelOutput,
|
||||||
QuestionAnsweringModelOutput,
|
QuestionAnsweringModelOutput,
|
||||||
SequenceClassifierOutput,
|
SequenceClassifierOutput,
|
||||||
TokenClassifierOutput,
|
TokenClassifierOutput,
|
||||||
@@ -805,7 +806,12 @@ class ModernBertPreTrainedModel(PreTrainedModel):
|
|||||||
init_weight(module.decoder, stds["out"])
|
init_weight(module.decoder, stds["out"])
|
||||||
elif isinstance(
|
elif isinstance(
|
||||||
module,
|
module,
|
||||||
(ModernBertForSequenceClassification, ModernBertForTokenClassification, ModernBertForQuestionAnswering),
|
(
|
||||||
|
ModernBertForSequenceClassification,
|
||||||
|
ModernBertForMultipleChoice,
|
||||||
|
ModernBertForTokenClassification,
|
||||||
|
ModernBertForQuestionAnswering,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
init_weight(module.classifier, stds["final_out"])
|
init_weight(module.classifier, stds["final_out"])
|
||||||
elif isinstance(module, nn.LayerNorm):
|
elif isinstance(module, nn.LayerNorm):
|
||||||
@@ -1521,6 +1527,123 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring(
    custom_intro="""
    The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
    """
)
class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.config = config

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        # A single output unit: each flattened (example, choice) row gets one score, and
        # `forward` regroups those scores into rows of `num_choices` logits.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Fail early with a clear message instead of an AttributeError on `None.shape` below.
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You must specify either input_ids or inputs_embeds")

        # The choice axis is dimension 1 of whichever input was provided.
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice axis into the leading axis so the backbone sees one sequence per row.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        self._maybe_set_compile()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        if self.config.classifier_pooling == "cls":
            # Keep only the hidden state at position 0 of every flattened row.
            last_hidden_state = last_hidden_state[:, 0]
        elif self.config.classifier_pooling == "mean":
            if attention_mask is None:
                # Without a mask every position counts, so the masked average reduces to a
                # plain mean over the sequence axis (previously this path raised AttributeError).
                last_hidden_state = last_hidden_state.mean(dim=1)
            else:
                # Average over the positions the attention mask marks as real tokens.
                last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(
                    dim=1
                ) / attention_mask.sum(dim=1, keepdim=True)

        pooled_output = self.head(last_hidden_state)
        pooled_output = self.drop(pooled_output)
        logits = self.classifier(pooled_output)

        # Undo the flattening: one row per example, one column per choice.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ModernBertConfig",
|
"ModernBertConfig",
|
||||||
"ModernBertModel",
|
"ModernBertModel",
|
||||||
@@ -1529,4 +1652,5 @@ __all__ = [
|
|||||||
"ModernBertForSequenceClassification",
|
"ModernBertForSequenceClassification",
|
||||||
"ModernBertForTokenClassification",
|
"ModernBertForTokenClassification",
|
||||||
"ModernBertForQuestionAnswering",
|
"ModernBertForQuestionAnswering",
|
||||||
|
"ModernBertForMultipleChoice",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
MODEL_FOR_PRETRAINING_MAPPING,
|
MODEL_FOR_PRETRAINING_MAPPING,
|
||||||
ModernBertForMaskedLM,
|
ModernBertForMaskedLM,
|
||||||
|
ModernBertForMultipleChoice,
|
||||||
ModernBertForQuestionAnswering,
|
ModernBertForQuestionAnswering,
|
||||||
ModernBertForSequenceClassification,
|
ModernBertForSequenceClassification,
|
||||||
ModernBertForTokenClassification,
|
ModernBertForTokenClassification,
|
||||||
@@ -202,6 +203,22 @@ class ModernBertModelTester:
|
|||||||
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||||
|
|
||||||
|
def create_and_check_for_multiple_choice(
    self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Run `ModernBertForMultipleChoice` on inputs repeated along a choice axis and check the logits shape.

    `sequence_labels` and `token_labels` are accepted only so the signature matches the tuple
    unpacked by the caller; they are not used here.
    """

    def repeat_per_choice(tensor):
        # Insert a new axis at position 1 and repeat the tensor `num_choices` times along it.
        return tensor.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()

    config.num_labels = self.num_labels

    model = ModernBertForMultipleChoice(config=config)
    model.to(torch_device)
    model.eval()

    choice_input_ids = repeat_per_choice(input_ids)
    choice_attention_mask = repeat_per_choice(input_mask)
    result = model(choice_input_ids, attention_mask=choice_attention_mask, labels=choice_labels)

    expected_shape = (self.batch_size, self.num_choices)
    self.parent.assertEqual(result.logits.shape, expected_shape)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -227,6 +244,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
ModernBertForSequenceClassification,
|
ModernBertForSequenceClassification,
|
||||||
ModernBertForTokenClassification,
|
ModernBertForTokenClassification,
|
||||||
ModernBertForQuestionAnswering,
|
ModernBertForQuestionAnswering,
|
||||||
|
ModernBertForMultipleChoice,
|
||||||
)
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
@@ -298,6 +316,7 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
ModernBertForSequenceClassification,
|
ModernBertForSequenceClassification,
|
||||||
ModernBertForTokenClassification,
|
ModernBertForTokenClassification,
|
||||||
ModernBertForQuestionAnswering,
|
ModernBertForQuestionAnswering,
|
||||||
|
ModernBertForMultipleChoice,
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
self.assertIn(
|
self.assertIn(
|
||||||
@@ -318,6 +337,10 @@ class ModernBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_for_multiple_choice(self):
    """Build a config and inputs with the model tester and run its multiple-choice check."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_multiple_choice(*prepared)
|
||||||
|
|
||||||
def test_for_warning_if_padding_and_no_attention_mask(self):
|
def test_for_warning_if_padding_and_no_attention_mask(self):
|
||||||
(
|
(
|
||||||
config,
|
config,
|
||||||
|
|||||||
Reference in New Issue
Block a user