Add TF image classification example script (#19956)

* TF image classification script

* Update requirements

* Fix up

* Add tests

* Update test fetcher
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix directory path

* Adding `zero-shot-object-detection` pipeline doctest. (#20274)

* Adding `zero-shot-object-detection` pipeline doctest.

* Remove nested_simplify.

* Add generate kwargs to `AutomaticSpeechRecognitionPipeline` (#20952)

* Add generate kwargs to AutomaticSpeechRecognitionPipeline

* Add test for generation kwargs

* Trigger CI

* Data collator returns np

* Update feature extractor -> image processor

* Bug fixes - updates to reflect changes in API

* Update flags to match PT & run faster

* Update instructions - Maria's comment

* Update examples/tensorflow/image-classification/README.md

* Remove slow decorator

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: bofeng huang <bofenghuang7@gmail.com>
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
amyeroberts
2023-02-01 19:09:36 +00:00
committed by GitHub
parent 3fadb4b211
commit e5db7051a8
5 changed files with 763 additions and 0 deletions

View File

@@ -38,6 +38,7 @@ SRC_DIRS = [
"question-answering",
"summarization",
"translation",
"image-classification",
]
]
sys.path.extend(SRC_DIRS)
@@ -45,6 +46,7 @@ sys.path.extend(SRC_DIRS)
if SRC_DIRS is not None:
import run_clm
import run_image_classification
import run_mlm
import run_ner
import run_qa as run_squad
@@ -294,3 +296,28 @@ class ExamplesTests(TestCasePlus):
run_translation.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["bleu"], 30)
def test_run_image_classification(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_image_classification.py
--dataset_name hf-internal-testing/cats_vs_dogs_sample
--model_name_or_path microsoft/resnet-18
--do_train
--do_eval
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 1
--output_dir {tmp_dir}
--overwrite_output_dir
--dataloader_num_workers 16
--num_train_epochs 2
--train_val_split 0.1
--seed 42
--ignore_mismatched_sizes True
""".split()
with patch.object(sys, "argv", testargs):
run_image_classification.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["accuracy"], 0.7)