Add TF DeiT implementation (#17806)

* Initial TF DeiT implementation

* Fix copies naming issues

* Fix up + docs

* Properly same main layer

* Name layers properly

* Initial TF DeiT implementation

* Fix copies naming issues

* Fix up + docs

* Properly same main layer

* Name layers properly

* Fixup

* Fix import

* Fix import

* Fix import

* Fix weight loading for tests whilst not on hub

* Add doc tests and remove to_2tuple

* Add back to_2tuple
Removing to_2tuple results in many downstream changes needed because of the copies checks

* Incorporate updates in Improve vision models #17731 PR

* Don't hard code num_channels

* Copy PyTorch DeiT embeddings and remove pytorch operations with mask

* Fix patch embeddings & tidy up

* Update PixelShuffle to move logic into class layer

* Update doc strings - remove PT references

* Use NHWC format in internal layers

* Fix up

* Use linear activation layer

* Remove unused import

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Move dataclass to top of file

* Remove from_pt now weights on hub

* Fixup

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Amy Roberts <amyeroberts@users.noreply.github.com>
This commit is contained in:
amyeroberts
2022-07-13 18:04:08 +01:00
committed by GitHub
parent 7ea6ccc2b3
commit 8581a798c0
10 changed files with 1471 additions and 3 deletions

View File

@@ -26,6 +26,7 @@ src/transformers/models/cvt/modeling_cvt.py
src/transformers/models/data2vec/modeling_data2vec_audio.py
src/transformers/models/data2vec/modeling_data2vec_vision.py
src/transformers/models/deit/modeling_deit.py
src/transformers/models/deit/modeling_tf_deit.py
src/transformers/models/detr/modeling_detr.py
src/transformers/models/dpt/modeling_dpt.py
src/transformers/models/electra/modeling_electra.py