Update FSNER code in examples->research_projects->fsner (#13864)
* Add example use of few-shot named entity recognition model in research_projects folder. * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update fsner example README.md. - Change wrong import FSNERTokenizerWrapper to FSNERTokenizerUtils in the example code - Add a link to the model identifier * Update examples/research_projects/fsner/src/fsner/model.py Fix spelling mistake in the default parameter of pretrained model name. Co-authored-by: Stefan Schweter <stefan@schweter.it> * Add example use of few-shot named entity recognition model in research_projects folder. * Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update fsner example README.md. - Change wrong import FSNERTokenizerWrapper to FSNERTokenizerUtils in the example code - Add a link to the model identifier * Update examples/research_projects/fsner/src/fsner/model.py Fix spelling mistake in the default parameter of pretrained model name. Co-authored-by: Stefan Schweter <stefan@schweter.it> * Run Checking/fixing examples/flax/language-modeling/run_clm_flax.py examples/flax/question-answering/run_qa.py examples/flax/question-answering/utils_qa.py examples/flax/token-classification/run_flax_ner.py examples/legacy/multiple_choice/utils_multiple_choice.py examples/legacy/seq2seq/seq2seq_trainer.py examples/legacy/token-classification/utils_ner.py examples/pytorch/image-classification/run_image_classification.py examples/pytorch/language-modeling/run_clm.py examples/pytorch/language-modeling/run_clm_no_trainer.py examples/pytorch/language-modeling/run_mlm.py examples/pytorch/language-modeling/run_mlm_no_trainer.py examples/pytorch/language-modeling/run_plm.py examples/pytorch/multiple-choice/run_swag.py examples/pytorch/multiple-choice/run_swag_no_trainer.py examples/pytorch/question-answering/run_qa.py examples/pytorch/question-answering/run_qa_beam_search.py examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py examples/pytorch/question-answering/run_qa_no_trainer.py examples/pytorch/summarization/run_summarization.py examples/pytorch/summarization/run_summarization_no_trainer.py examples/pytorch/test_examples.py examples/pytorch/text-classification/run_glue.py examples/pytorch/text-classification/run_glue_no_trainer.py examples/pytorch/text-classification/run_xnli.py examples/pytorch/token-classification/run_ner.py examples/pytorch/token-classification/run_ner_no_trainer.py examples/pytorch/translation/run_translation.py examples/pytorch/translation/run_translation_no_trainer.py examples/research_projects/adversarial/utils_hans.py examples/research_projects/distillation/grouped_batch_sampler.py examples/research_projects/fsner/setup.py examples/research_projects/fsner/src/fsner/__init__.py examples/research_projects/fsner/src/fsner/model.py examples/research_projects/fsner/src/fsner/tokenizer_utils.py examples/research_projects/jax-projects/big_bird/evaluate.py examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py examples/tensorflow/language-modeling/run_clm.py examples/tensorflow/multiple-choice/run_swag.py examples/tensorflow/question-answering/run_qa.py examples/tensorflow/summarization/run_summarization.py examples/tensorflow/text-classification/run_glue.py examples/tensorflow/translation/run_translation.py src/transformers/__init__.py src/transformers/commands/add_new_model.py src/transformers/configuration_utils.py src/transformers/convert_slow_tokenizer.py src/transformers/data/__init__.py src/transformers/data/data_collator.py src/transformers/data/datasets/glue.py src/transformers/data/datasets/language_modeling.py src/transformers/data/datasets/squad.py src/transformers/deepspeed.py src/transformers/dependency_versions_table.py src/transformers/feature_extraction_sequence_utils.py src/transformers/file_utils.py src/transformers/generation_flax_utils.py src/transformers/generation_logits_process.py src/transformers/generation_tf_utils.py src/transformers/generation_utils.py src/transformers/integrations.py src/transformers/modelcard.py src/transformers/modeling_flax_utils.py src/transformers/modeling_outputs.py src/transformers/modeling_tf_utils.py src/transformers/modeling_utils.py src/transformers/models/__init__.py src/transformers/models/albert/__init__.py src/transformers/models/albert/modeling_albert.py src/transformers/models/albert/modeling_flax_albert.py src/transformers/models/albert/tokenization_albert_fast.py src/transformers/models/auto/__init__.py src/transformers/models/auto/auto_factory.py src/transformers/models/auto/configuration_auto.py src/transformers/models/auto/dynamic.py src/transformers/models/auto/feature_extraction_auto.py src/transformers/models/auto/modeling_auto.py src/transformers/models/auto/modeling_flax_auto.py src/transformers/models/auto/modeling_tf_auto.py src/transformers/models/auto/tokenization_auto.py src/transformers/models/bart/configuration_bart.py src/transformers/models/bart/modeling_bart.py src/transformers/models/bart/modeling_flax_bart.py src/transformers/models/bart/modeling_tf_bart.py src/transformers/models/barthez/tokenization_barthez_fast.py src/transformers/models/beit/__init__.py src/transformers/models/beit/configuration_beit.py src/transformers/models/beit/modeling_beit.py src/transformers/models/beit/modeling_flax_beit.py src/transformers/models/bert/configuration_bert.py src/transformers/models/bert/modeling_bert.py src/transformers/models/bert/modeling_flax_bert.py src/transformers/models/bert_generation/configuration_bert_generation.py src/transformers/models/bert_generation/modeling_bert_generation.py src/transformers/models/big_bird/configuration_big_bird.py src/transformers/models/big_bird/modeling_big_bird.py src/transformers/models/big_bird/modeling_flax_big_bird.py src/transformers/models/big_bird/tokenization_big_bird_fast.py src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py src/transformers/models/blenderbot/configuration_blenderbot.py src/transformers/models/blenderbot/modeling_blenderbot.py src/transformers/models/blenderbot/modeling_tf_blenderbot.py src/transformers/models/blenderbot_small/configuration_blenderbot_small.py src/transformers/models/blenderbot_small/modeling_blenderbot_small.py src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py src/transformers/models/byt5/tokenization_byt5.py src/transformers/models/camembert/tokenization_camembert_fast.py src/transformers/models/canine/configuration_canine.py src/transformers/models/canine/modeling_canine.py src/transformers/models/clip/configuration_clip.py src/transformers/models/clip/convert_clip_original_pytorch_to_hf.py src/transformers/models/clip/modeling_clip.py src/transformers/models/clip/modeling_flax_clip.py src/transformers/models/clip/tokenization_clip.py src/transformers/models/convbert/modeling_convbert.py src/transformers/models/ctrl/configuration_ctrl.py src/transformers/models/deberta/modeling_tf_deberta.py src/transformers/models/deberta_v2/__init__.py src/transformers/models/deberta_v2/modeling_deberta_v2.py src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py src/transformers/models/deit/configuration_deit.py src/transformers/models/deit/modeling_deit.py src/transformers/models/detr/configuration_detr.py src/transformers/models/detr/modeling_detr.py src/transformers/models/distilbert/__init__.py src/transformers/models/distilbert/configuration_distilbert.py src/transformers/models/distilbert/modeling_distilbert.py src/transformers/models/distilbert/modeling_flax_distilbert.py src/transformers/models/dpr/configuration_dpr.py src/transformers/models/dpr/modeling_dpr.py src/transformers/models/electra/modeling_electra.py src/transformers/models/electra/modeling_flax_electra.py src/transformers/models/encoder_decoder/__init__.py src/transformers/models/encoder_decoder/modeling_encoder_decoder.py src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py src/transformers/models/flaubert/configuration_flaubert.py src/transformers/models/flaubert/modeling_flaubert.py src/transformers/models/fnet/__init__.py src/transformers/models/fnet/configuration_fnet.py src/transformers/models/fnet/convert_fnet_original_flax_checkpoint_to_pytorch.py src/transformers/models/fnet/modeling_fnet.py src/transformers/models/fnet/tokenization_fnet.py src/transformers/models/fnet/tokenization_fnet_fast.py src/transformers/models/fsmt/configuration_fsmt.py src/transformers/models/fsmt/modeling_fsmt.py src/transformers/models/funnel/configuration_funnel.py src/transformers/models/gpt2/__init__.py src/transformers/models/gpt2/configuration_gpt2.py src/transformers/models/gpt2/modeling_flax_gpt2.py src/transformers/models/gpt2/modeling_gpt2.py src/transformers/models/gpt2/modeling_tf_gpt2.py src/transformers/models/gpt_neo/configuration_gpt_neo.py src/transformers/models/gpt_neo/modeling_gpt_neo.py src/transformers/models/gptj/__init__.py src/transformers/models/gptj/configuration_gptj.py src/transformers/models/gptj/modeling_gptj.py src/transformers/models/herbert/tokenization_herbert_fast.py src/transformers/models/hubert/__init__.py src/transformers/models/hubert/configuration_hubert.py src/transformers/models/hubert/convert_hubert_original_s3prl_checkpoint_to_pytorch.py src/transformers/models/hubert/modeling_hubert.py src/transformers/models/hubert/modeling_tf_hubert.py src/transformers/models/ibert/modeling_ibert.py src/transformers/models/layoutlm/__init__.py src/transformers/models/layoutlm/configuration_layoutlm.py src/transformers/models/layoutlm/modeling_layoutlm.py src/transformers/models/layoutlmv2/__init__.py src/transformers/models/layoutlmv2/configuration_layoutlmv2.py src/transformers/models/layoutlmv2/feature_extraction_layoutlmv2.py src/transformers/models/layoutlmv2/modeling_layoutlmv2.py src/transformers/models/layoutlmv2/processing_layoutlmv2.py src/transformers/models/layoutlmv2/tokenization_layoutlmv2.py src/transformers/models/layoutlmv2/tokenization_layoutlmv2_fast.py src/transformers/models/led/configuration_led.py src/transformers/models/led/modeling_led.py src/transformers/models/longformer/modeling_longformer.py src/transformers/models/luke/configuration_luke.py src/transformers/models/luke/modeling_luke.py src/transformers/models/luke/tokenization_luke.py src/transformers/models/lxmert/configuration_lxmert.py src/transformers/models/m2m_100/configuration_m2m_100.py src/transformers/models/m2m_100/modeling_m2m_100.py src/transformers/models/m2m_100/tokenization_m2m_100.py src/transformers/models/marian/configuration_marian.py src/transformers/models/marian/modeling_flax_marian.py src/transformers/models/marian/modeling_marian.py src/transformers/models/marian/modeling_tf_marian.py src/transformers/models/mbart/configuration_mbart.py src/transformers/models/mbart/modeling_flax_mbart.py src/transformers/models/mbart/modeling_mbart.py src/transformers/models/mbart/tokenization_mbart.py src/transformers/models/mbart/tokenization_mbart_fast.py src/transformers/models/mbart50/tokenization_mbart50.py src/transformers/models/mbart50/tokenization_mbart50_fast.py src/transformers/models/megatron_bert/configuration_megatron_bert.py src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py src/transformers/models/megatron_bert/modeling_megatron_bert.py src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py src/transformers/models/openai/configuration_openai.py src/transformers/models/pegasus/__init__.py src/transformers/models/pegasus/configuration_pegasus.py src/transformers/models/pegasus/modeling_flax_pegasus.py src/transformers/models/pegasus/modeling_pegasus.py src/transformers/models/pegasus/modeling_tf_pegasus.py src/transformers/models/pegasus/tokenization_pegasus_fast.py src/transformers/models/prophetnet/configuration_prophetnet.py src/transformers/models/prophetnet/modeling_prophetnet.py src/transformers/models/rag/modeling_rag.py src/transformers/models/rag/modeling_tf_rag.py src/transformers/models/reformer/configuration_reformer.py src/transformers/models/reformer/tokenization_reformer_fast.py src/transformers/models/rembert/configuration_rembert.py src/transformers/models/rembert/modeling_rembert.py src/transformers/models/rembert/tokenization_rembert_fast.py src/transformers/models/roberta/modeling_flax_roberta.py src/transformers/models/roberta/modeling_roberta.py src/transformers/models/roberta/modeling_tf_roberta.py src/transformers/models/roformer/configuration_roformer.py src/transformers/models/roformer/modeling_roformer.py src/transformers/models/speech_encoder_decoder/__init__.py src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py src/transformers/models/speech_to_text/configuration_speech_to_text.py src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py src/transformers/models/speech_to_text/modeling_speech_to_text.py src/transformers/models/speech_to_text_2/__init__.py src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py src/transformers/models/speech_to_text_2/processing_speech_to_text_2.py src/transformers/models/speech_to_text_2/tokenization_speech_to_text_2.py src/transformers/models/splinter/configuration_splinter.py src/transformers/models/splinter/modeling_splinter.py src/transformers/models/t5/configuration_t5.py src/transformers/models/t5/modeling_flax_t5.py src/transformers/models/t5/modeling_t5.py src/transformers/models/t5/modeling_tf_t5.py src/transformers/models/t5/tokenization_t5_fast.py src/transformers/models/tapas/__init__.py src/transformers/models/tapas/configuration_tapas.py src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py src/transformers/models/tapas/modeling_tapas.py src/transformers/models/tapas/tokenization_tapas.py src/transformers/models/transfo_xl/configuration_transfo_xl.py src/transformers/models/visual_bert/modeling_visual_bert.py src/transformers/models/vit/configuration_vit.py src/transformers/models/vit/convert_dino_to_pytorch.py src/transformers/models/vit/modeling_flax_vit.py src/transformers/models/vit/modeling_vit.py src/transformers/models/wav2vec2/__init__.py src/transformers/models/wav2vec2/configuration_wav2vec2.py src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py src/transformers/models/wav2vec2/modeling_wav2vec2.py src/transformers/models/wav2vec2/tokenization_wav2vec2.py src/transformers/models/xlm/configuration_xlm.py src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py src/transformers/models/xlnet/configuration_xlnet.py src/transformers/models/xlnet/tokenization_xlnet_fast.py src/transformers/onnx/convert.py src/transformers/onnx/features.py src/transformers/optimization.py src/transformers/pipelines/__init__.py src/transformers/pipelines/audio_classification.py src/transformers/pipelines/automatic_speech_recognition.py src/transformers/pipelines/base.py src/transformers/pipelines/conversational.py src/transformers/pipelines/feature_extraction.py src/transformers/pipelines/fill_mask.py src/transformers/pipelines/image_classification.py src/transformers/pipelines/object_detection.py src/transformers/pipelines/question_answering.py src/transformers/pipelines/table_question_answering.py src/transformers/pipelines/text2text_generation.py src/transformers/pipelines/text_classification.py src/transformers/pipelines/text_generation.py src/transformers/pipelines/token_classification.py src/transformers/pipelines/zero_shot_classification.py src/transformers/testing_utils.py src/transformers/tokenization_utils.py src/transformers/tokenization_utils_base.py src/transformers/tokenization_utils_fast.py src/transformers/trainer.py src/transformers/trainer_callback.py src/transformers/trainer_pt_utils.py src/transformers/trainer_seq2seq.py src/transformers/trainer_utils.py src/transformers/training_args.py src/transformers/training_args_seq2seq.py src/transformers/utils/dummy_detectron2_objects.py src/transformers/utils/dummy_flax_objects.py src/transformers/utils/dummy_pt_objects.py src/transformers/utils/dummy_tf_objects.py src/transformers/utils/dummy_tokenizers_objects.py src/transformers/utils/dummy_vision_objects.py tests/deepspeed/test_deepspeed.py tests/sagemaker/conftest.py tests/sagemaker/test_multi_node_data_parallel.py tests/test_configuration_auto.py tests/test_configuration_common.py tests/test_data_collator.py tests/test_feature_extraction_auto.py tests/test_feature_extraction_layoutlmv2.py tests/test_feature_extraction_speech_to_text.py tests/test_feature_extraction_wav2vec2.py tests/test_file_utils.py tests/test_modeling_auto.py tests/test_modeling_bart.py tests/test_modeling_beit.py tests/test_modeling_bert.py tests/test_modeling_clip.py tests/test_modeling_common.py tests/test_modeling_convbert.py tests/test_modeling_deit.py tests/test_modeling_distilbert.py tests/test_modeling_encoder_decoder.py tests/test_modeling_flaubert.py tests/test_modeling_flax_albert.py tests/test_modeling_flax_bart.py tests/test_modeling_flax_beit.py tests/test_modeling_flax_distilbert.py tests/test_modeling_flax_encoder_decoder.py tests/test_modeling_flax_gpt2.py tests/test_modeling_flax_gpt_neo.py tests/test_modeling_flax_mt5.py tests/test_modeling_flax_pegasus.py tests/test_modeling_fnet.py tests/test_modeling_gpt2.py tests/test_modeling_gpt_neo.py tests/test_modeling_gptj.py tests/test_modeling_hubert.py tests/test_modeling_layoutlmv2.py tests/test_modeling_pegasus.py tests/test_modeling_rag.py tests/test_modeling_reformer.py tests/test_modeling_speech_encoder_decoder.py tests/test_modeling_speech_to_text.py tests/test_modeling_speech_to_text_2.py tests/test_modeling_tf_auto.py tests/test_modeling_tf_deberta_v2.py tests/test_modeling_tf_hubert.py tests/test_modeling_tf_pytorch.py tests/test_modeling_tf_wav2vec2.py tests/test_modeling_wav2vec2.py tests/test_onnx_v2.py tests/test_pipelines_audio_classification.py tests/test_pipelines_automatic_speech_recognition.py tests/test_pipelines_common.py tests/test_pipelines_conversational.py tests/test_pipelines_feature_extraction.py tests/test_pipelines_fill_mask.py tests/test_pipelines_image_classification.py tests/test_pipelines_object_detection.py tests/test_pipelines_question_answering.py tests/test_pipelines_summarization.py tests/test_pipelines_table_question_answering.py tests/test_pipelines_text2text_generation.py tests/test_pipelines_text_classification.py tests/test_pipelines_text_generation.py tests/test_pipelines_token_classification.py tests/test_pipelines_translation.py tests/test_pipelines_zero_shot.py tests/test_processor_layoutlmv2.py tests/test_processor_wav2vec2.py tests/test_sequence_feature_extraction_common.py tests/test_tokenization_auto.py tests/test_tokenization_byt5.py tests/test_tokenization_canine.py tests/test_tokenization_common.py tests/test_tokenization_fnet.py tests/test_tokenization_layoutlmv2.py tests/test_tokenization_luke.py tests/test_tokenization_mbart.py tests/test_tokenization_mbart50.py tests/test_tokenization_speech_to_text_2.py tests/test_tokenization_t5.py tests/test_tokenization_tapas.py tests/test_tokenization_xlm_roberta.py tests/test_trainer.py tests/test_trainer_distributed.py tests/test_trainer_tpu.py tests/test_utils_check_copies.py utils/check_copies.py utils/check_repo.py utils/notification_service.py utils/release.py utils/tests_fetcher.py python utils/custom_init_isort.py python utils/style_doc.py src/transformers docs/source --max_len 119 running deps_table_update updating src/transformers/dependency_versions_table.py python utils/check_copies.py python utils/check_table.py python utils/check_dummies.py python utils/check_repo.py Checking all models are public. Checking all models are properly tested. Checking all objects are properly documented. Checking all models are in at least one auto class. python utils/check_inits.py python utils/tests_fetcher.py --sanity_check and fix suggested changes. * Run black examples tests src utils isort examples tests src utils Skipped 1 files make autogenerate_code make[1]: Entering directory '/mnt/c/Users/Admin/Desktop/Home/Projects/transformers' running deps_table_update updating src/transformers/dependency_versions_table.py make[1]: Leaving directory '/mnt/c/Users/Admin/Desktop/Home/Projects/transformers' make extra_style_checks make[1]: Entering directory '/mnt/c/Users/Admin/Desktop/Home/Projects/transformers' python utils/custom_init_isort.py python utils/style_doc.py src/transformers docs/source --max_len 119 make[1]: Leaving directory '/mnt/c/Users/Admin/Desktop/Home/Projects/transformers' for reformatting code. * Add installation dependencies for examples/research_projects/fsner. * Add support to pass in variable numbers of examples to FSNER model. * Retrieve start_token_id and end_token_id from tokenizer instead of hardcoding in the FSNER model. * Run black examples tests src utils isort examples tests src utils Skipped 1 files make autogenerate_code make[1]: Entering directory '/home/saif/transformers' running deps_table_update updating src/transformers/dependency_versions_table.py make[1]: Leaving directory '/home/saif/transformers' make extra_style_checks make[1]: Entering directory '/home/saif/transformers' python utils/custom_init_isort.py python utils/style_doc.py src/transformers docs/source --max_len 119 make[1]: Leaving directory '/home/saif/transformers' for FSNER * Update FSNER readme.md with a header image. * Update FSNER readme Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Stefan Schweter <stefan@schweter.it>
This commit is contained in:
committed by
GitHub
parent
e7b16f33ae
commit
155b23008e
@@ -1,6 +1,8 @@
|
||||
# FSNER
|
||||
<p align="center"> <img src="http://sayef.tech:8082/uploads/FSNER-LOGO-2.png" alt="FSNER LOGO"> </p>
|
||||
|
||||
Implemented by [sayef](https://huggingface.co/sayef).
|
||||
<p align="center">
|
||||
Implemented by <a href="https://huggingface.co/sayef"> sayef </a>.
|
||||
</p>
|
||||
|
||||
## Overview
|
||||
|
||||
@@ -25,13 +27,17 @@ The FSNER model was proposed in [Example-Based Named Entity Recognition](https:/
|
||||
## Installation and Example Usage
|
||||
------
|
||||
|
||||
You can use the FSNER model in two ways:
|
||||
You can use the FSNER model in 3 ways:
|
||||
|
||||
1. Install as a package: `python setup.py install` and import the model as shown in the code example below
|
||||
1. Install directly from PyPI: `pip install fsner` and import the model as shown in the code example below
|
||||
|
||||
or
|
||||
|
||||
2. Change directory to `src` and import the model as shown in the code example below
|
||||
2. Install from source: `python setup.py install` and import the model as shown in the code example below
|
||||
|
||||
or
|
||||
|
||||
3. Clone repo and change directory to `src` and import the model as shown in the code example below
|
||||
|
||||
|
||||
|
||||
@@ -65,7 +71,6 @@ supports = [
|
||||
"[E] Walmart [/E] is a leading e-commerce company",
|
||||
"I recently ordered a book from [E] Amazon [/E]",
|
||||
"I ordered this from [E] ShopClues [/E]",
|
||||
"Fridge can be ordered in [E] Amazon [/E]",
|
||||
"[E] Flipkart [/E] started it's journey from zero"
|
||||
]
|
||||
]
|
||||
@@ -73,7 +78,7 @@ supports = [
|
||||
device = 'cpu'
|
||||
|
||||
W_query = tokenizer.tokenize(query).to(device)
|
||||
W_supports = tokenizer.tokenize([s for support in supports for s in support]).to(device)
|
||||
W_supports = tokenizer.tokenize(supports).to(device)
|
||||
|
||||
start_prob, end_prob = model(W_query, W_supports)
|
||||
|
||||
|
||||
@@ -40,22 +40,41 @@ class FSNERModel(torch.nn.Module):
|
||||
p_end (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Scores of each token as
|
||||
being end token of an entity
|
||||
"""
|
||||
|
||||
support_sizes = W_supports["sizes"].tolist()
|
||||
start_token_id = W_supports["start_token_id"].item()
|
||||
end_token_id = W_supports["end_token_id"].item()
|
||||
|
||||
del W_supports["sizes"]
|
||||
del W_supports["start_token_id"]
|
||||
del W_supports["end_token_id"]
|
||||
|
||||
q = self.BERT(**W_query)
|
||||
S = self.BERT(**W_supports)
|
||||
|
||||
# reshape from (batch_size, 384, 784) to (batch_size, 1, 384, 784)
|
||||
q = q.view(q.shape[0], -1, q.shape[1], q.shape[2])
|
||||
p_starts = None
|
||||
p_ends = None
|
||||
|
||||
# reshape from (batch_size*n_exaples_per_entity, 384, 784) to (batch_size, n_exaples_per_entity, 384, 784)
|
||||
S = S.view(q.shape[0], -1, S.shape[1], S.shape[2])
|
||||
start_token_masks = W_supports["input_ids"] == start_token_id
|
||||
end_token_masks = W_supports["input_ids"] == end_token_id
|
||||
|
||||
s_start = S[(W_supports["input_ids"] == 30522).view(S.shape[:3])].view(S.shape[0], -1, 1, S.shape[-1])
|
||||
s_end = S[(W_supports["input_ids"] == 30523).view(S.shape[:3])].view(S.shape[0], -1, 1, S.shape[-1])
|
||||
for i, size in enumerate(support_sizes):
|
||||
if i == 0:
|
||||
s = 0
|
||||
else:
|
||||
s = support_sizes[i - 1]
|
||||
|
||||
p_start = torch.sum(torch.einsum("bitf,bejf->bet", q, s_start), dim=1)
|
||||
p_end = torch.sum(torch.einsum("bitf,bejf->bet", q, s_end), dim=1)
|
||||
s_start = S[s : s + size][start_token_masks[s : s + size]]
|
||||
s_end = S[s : s + size][end_token_masks[s : s + size]]
|
||||
|
||||
p_start = p_start.softmax(dim=1)
|
||||
p_end = p_end.softmax(dim=1)
|
||||
p_start = torch.matmul(q[i], s_start.T).sum(1).softmax(0)
|
||||
p_end = torch.matmul(q[i], s_end.T).sum(1).softmax(0)
|
||||
|
||||
return p_start, p_end
|
||||
if p_starts is not None:
|
||||
p_starts = torch.vstack((p_starts, p_start))
|
||||
p_ends = torch.vstack((p_ends, p_end))
|
||||
else:
|
||||
p_starts = p_start
|
||||
p_ends = p_end
|
||||
|
||||
return p_starts, p_ends
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import torch
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
@@ -6,7 +8,50 @@ class FSNERTokenizerUtils(object):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
|
||||
|
||||
def tokenize(self, x):
|
||||
return self.tokenizer(x, padding="max_length", max_length=384, truncation=True, return_tensors="pt")
|
||||
"""
|
||||
Wrapper function for tokenizing query and supports
|
||||
Args:
|
||||
x (`List[str] or List[List[str]]`):
|
||||
List of strings for query or list of lists of strings for supports.
|
||||
Returns:
|
||||
`transformers.tokenization_utils_base.BatchEncoding` dict with additional keys and values for start_token_id, end_token_id and sizes of example lists for each entity type
|
||||
"""
|
||||
|
||||
if isinstance(x, list) and all([isinstance(_x, list) for _x in x]):
|
||||
d = None
|
||||
for l in x:
|
||||
t = self.tokenizer(
|
||||
l,
|
||||
padding="max_length",
|
||||
max_length=384,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
t["sizes"] = torch.tensor([len(l)])
|
||||
if d is not None:
|
||||
for k in d.keys():
|
||||
d[k] = torch.cat((d[k], t[k]), 0)
|
||||
else:
|
||||
d = t
|
||||
|
||||
d["start_token_id"] = torch.tensor(self.tokenizer.convert_tokens_to_ids("[E]"))
|
||||
d["end_token_id"] = torch.tensor(self.tokenizer.convert_tokens_to_ids("[/E]"))
|
||||
|
||||
elif isinstance(x, list) and all([isinstance(_x, str) for _x in x]):
|
||||
d = self.tokenizer(
|
||||
x,
|
||||
padding="max_length",
|
||||
max_length=384,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception(
|
||||
"Type of parameter x was not recognized! Only `list of strings` for query or `list of lists of strings` for supports are supported."
|
||||
)
|
||||
|
||||
return d
|
||||
|
||||
def extract_entity_from_scores(self, query, W_query, p_start, p_end, thresh=0.70):
|
||||
"""
|
||||
@@ -34,7 +79,14 @@ class FSNERTokenizerUtils(object):
|
||||
for start_id in start_indexes:
|
||||
for end_id in end_indexes:
|
||||
if start_id < end_id:
|
||||
output.append((start_id, end_id, p_start[idx][start_id].item(), p_end[idx][end_id].item()))
|
||||
output.append(
|
||||
(
|
||||
start_id,
|
||||
end_id,
|
||||
p_start[idx][start_id].item(),
|
||||
p_end[idx][end_id].item(),
|
||||
)
|
||||
)
|
||||
|
||||
output.sort(key=lambda tup: (tup[2] * tup[3]), reverse=True)
|
||||
temp = []
|
||||
|
||||
Reference in New Issue
Block a user