Add TF image classification example script (#19956)
* TF image classification script * Update requirements * Fix up * Add tests * Update test fetcher Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix directory path * Adding `zero-shot-object-detection` pipeline doctest. (#20274) * Adding `zero-shot-object-detection` pipeline doctest. * Remove nested_simplify. * Add generate kwargs to `AutomaticSpeechRecognitionPipeline` (#20952) * Add generate kwargs to AutomaticSpeechRecognitionPipeline * Add test for generation kwargs * Trigger CI * Data collator returns np * Update feature extractor -> image processor * Bug fixes - updates to reflect changes in API * Update flags to match PT & run faster * Update instructions - Maria's comment * Update examples/tensorflow/image-classification/README.md * Remove slow decorator --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: bofeng huang <bofenghuang7@gmail.com> Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
This commit is contained in:
@@ -38,6 +38,7 @@ SRC_DIRS = [
|
||||
"question-answering",
|
||||
"summarization",
|
||||
"translation",
|
||||
"image-classification",
|
||||
]
|
||||
]
|
||||
sys.path.extend(SRC_DIRS)
|
||||
@@ -45,6 +46,7 @@ sys.path.extend(SRC_DIRS)
|
||||
|
||||
if SRC_DIRS is not None:
|
||||
import run_clm
|
||||
import run_image_classification
|
||||
import run_mlm
|
||||
import run_ner
|
||||
import run_qa as run_squad
|
||||
@@ -294,3 +296,28 @@ class ExamplesTests(TestCasePlus):
|
||||
run_translation.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["bleu"], 30)
|
||||
|
||||
def test_run_image_classification(self):
|
||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||
testargs = f"""
|
||||
run_image_classification.py
|
||||
--dataset_name hf-internal-testing/cats_vs_dogs_sample
|
||||
--model_name_or_path microsoft/resnet-18
|
||||
--do_train
|
||||
--do_eval
|
||||
--learning_rate 1e-4
|
||||
--per_device_train_batch_size 2
|
||||
--per_device_eval_batch_size 1
|
||||
--output_dir {tmp_dir}
|
||||
--overwrite_output_dir
|
||||
--dataloader_num_workers 16
|
||||
--num_train_epochs 2
|
||||
--train_val_split 0.1
|
||||
--seed 42
|
||||
--ignore_mismatched_sizes True
|
||||
""".split()
|
||||
|
||||
with patch.object(sys, "argv", testargs):
|
||||
run_image_classification.main()
|
||||
result = get_results(tmp_dir)
|
||||
self.assertGreaterEqual(result["accuracy"], 0.7)
|
||||
|
||||
Reference in New Issue
Block a user