Adding [T5/MT5/UMT5]ForTokenClassification (#28443)
* Adding [T5/MT5/UMT5]ForTokenClassification
* Add auto mappings for T5ForTokenClassification and variants
* Adding ForTokenClassification to the list of models
* Adding attention_mask param to the T5ForTokenClassification test
* Remove outdated comment in test
* Adding EncoderOnly and Token Classification tests for MT5 and UMT5
* Fix typo in umt5 string
* Add tests for all the existing MT5 models
* Fix wrong comment in dependency_versions_table
* Reverting change to common test for _keys_to_ignore_on_load_missing
  The test is correctly picking up redundant keys in _keys_to_ignore_on_load_missing.
* Removing _keys_to_ignore_on_missing from MT5 since the key is not used in the model
* Add fix-copies to MT5ModelTest
This commit is contained in:
@@ -52,6 +52,7 @@ if is_torch_available():
|
||||
T5ForConditionalGeneration,
|
||||
T5ForQuestionAnswering,
|
||||
T5ForSequenceClassification,
|
||||
T5ForTokenClassification,
|
||||
T5Model,
|
||||
T5Tokenizer,
|
||||
)
|
||||
@@ -586,9 +587,11 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
||||
# `QAPipelineTests` is not working well with slow tokenizers (for some models) and we don't want to touch the file
|
||||
# `src/transformers/data/processors/squad.py` (where this test fails for this model)
|
||||
def is_pipeline_test_to_skip(
    self, pipeline_test_case_name, config_class, model_architecture, tokenizer_name, processor_name
):
    """Decide whether a pipeline test case should be skipped for this model.

    Args:
        pipeline_test_case_name: Name of the pipeline test case class.
        config_class: Unused here; part of the hook signature.
        model_architecture: Unused here; part of the hook signature.
        tokenizer_name: Name of the tokenizer class, or `None` when there is none.
        processor_name: Unused here; part of the hook signature.

    Returns:
        `True` when no tokenizer name is given, or when the test case is
        `QAPipelineTests` and the tokenizer name does not end with "Fast";
        `False` otherwise.
    """
    # Guard first: `.endswith` below would raise on a missing tokenizer name.
    if tokenizer_name is None:
        return True
    if pipeline_test_case_name == "QAPipelineTests" and not tokenizer_name.endswith("Fast"):
        return True

    return False
|
||||
@@ -998,6 +1001,22 @@ class T5EncoderOnlyModelTester:
|
||||
output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_with_token_classification_head(
    self,
    config,
    input_ids,
    attention_mask,
):
    """Check the output shapes of `T5ForTokenClassification` for one forward pass.

    Builds the model from `config`, moves it to `torch_device` in eval mode, and
    calls it with `input_ids`, `attention_mask` and a flat label tensor holding
    class 1 for each of the `seq_length * batch_size` positions. Asserts that the
    logits have size `(batch_size, seq_length, config.num_labels)` and that the
    loss has size `()`.
    """
    # One label per token position, flattened across the batch.
    num_tokens = self.seq_length * self.batch_size
    labels = torch.tensor([1] * num_tokens, dtype=torch.long, device=torch_device)

    model = T5ForTokenClassification(config=config)
    model.to(torch_device)
    model.eval()

    outputs = model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)

    expected_logits_size = (self.batch_size, self.seq_length, config.num_labels)
    self.parent.assertEqual(outputs["logits"].size(), expected_logits_size)
    self.parent.assertEqual(outputs["loss"].size(), ())
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -1013,11 +1032,18 @@ class T5EncoderOnlyModelTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (T5EncoderModel,) if is_torch_available() else ()
|
||||
class T5EncoderOnlyModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (T5EncoderModel, T5ForTokenClassification) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_resize_embeddings = False
|
||||
test_model_parallel = True
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"token-classification": T5ForTokenClassification,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
all_parallelizable_model_classes = (T5EncoderModel,) if is_torch_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
@@ -1036,6 +1062,10 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
|
||||
|
||||
def test_with_token_classification_head(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_with_token_classification_head(*config_and_inputs)
|
||||
|
||||
|
||||
def use_task_specific_params(model, task):
    """Update `model.config` with the entry of `config.task_specific_params` stored under `task`.

    Raises whatever the `task_specific_params[task]` lookup raises when `task` is absent.
    """
    config = model.config
    apply_overrides = config.update
    apply_overrides(config.task_specific_params[task])
|
||||
|
||||
Reference in New Issue
Block a user