Add examples for detection models finetuning (#30422)
* Training script for object detection * Evaluation script for object detection * Training script for object detection with eval loop outside trainer * Trainer DETR finetuning * No trainer DETR finetuning * Eval script * Refine object detection example with trainer * Remove commented code and enable telemetry * No trainer example * Add requirements for object detection examples * Add test for trainer example * Readme draft * Fix uploading to HUB * Readme improvements * Update eval script * Adding tests for object-detection examples * Add object-detection example * Add object-detection resources to docs * Update README with custom dataset instructions * Update year * Replace valid with validation * Update instructions for custom dataset * Remove eval script * Remove use_auth_token * Add copied from and telemetry * Fixup * Update readme * Fix id2label * Fix links in docs * Update examples/pytorch/object-detection/run_object_detection.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update examples/pytorch/object-detection/run_object_detection.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Move description to the top * Fix Trainer example * Update no trainer example * Update albumentations version --------- Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
508c0bfe55
commit
998dbe068b
@@ -48,6 +48,7 @@ SRC_DIRS = [
|
||||
"speech-pretraining",
|
||||
"image-pretraining",
|
||||
"semantic-segmentation",
|
||||
"object-detection",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@@ -62,6 +63,7 @@ if SRC_DIRS is not None:
|
||||
import run_mae
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_object_detection
|
||||
import run_qa as run_squad
|
||||
import run_semantic_segmentation
|
||||
import run_seq2seq_qa as run_squad_seq2seq
|
||||
@@ -609,3 +611,31 @@ class ExamplesTests(TestCasePlus):
|
||||
run_semantic_segmentation.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["eval_overall_accuracy"], 0.1)
|
||||
|
||||
@patch.dict(os.environ, {"WANDB_DISABLED": "true"})
|
||||
def test_run_object_detection(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_object_detection.py
|
||||
--model_name_or_path qubvel-hf/detr-resnet-50-finetuned-10k-cppe5
|
||||
--output_dir {tmp_dir}
|
||||
--dataset_name qubvel-hf/cppe-5-sample
|
||||
--do_train
|
||||
--do_eval
|
||||
--remove_unused_columns False
|
||||
--overwrite_output_dir True
|
||||
--eval_do_concat_batches False
|
||||
--max_steps 10
|
||||
--learning_rate=1e-6
|
||||
--per_device_train_batch_size=2
|
||||
--per_device_eval_batch_size=1
|
||||
--seed 32
|
||||
""".split()
|
||||
|
||||
if is_torch_fp16_available_on_device(torch_device):
|
||||
testargs.append("--fp16")
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_object_detection.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["test_map"], 0.1)
|
||||
|
||||
Reference in New Issue
Block a user