Resnet flax (#21472)

* [WIP] flax resnet

* added pretrained flax models, results reproducible

* Added pretrained flax models, results reproducible

* working on tests

* no real code change, just some comments

* [flax] adding support for batch norm layers

* fixing bugs related to pt+flax integration

* removing loss from modeling flax output class

* fixing classifier tests

* fixing comments, model output

* cleaning comments

* review changes

* review changes

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* renaming Flax to PyTorch

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Shubhamai
2023-03-25 01:15:57 +05:30
committed by GitHub
parent 88dae78f4d
commit a0cbbba31f
15 changed files with 1057 additions and 8 deletions

View File

@@ -71,3 +71,13 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] TFResNetForImageClassification
- call
## FlaxResNetModel
[[autodoc]] FlaxResNetModel
- __call__
## FlaxResNetForImageClassification
[[autodoc]] FlaxResNetForImageClassification
- __call__