Commit Graph

1603 Commits

Author SHA1 Message Date
Patrick von Platen
ddbb485c41 [TF-PT-Tests] Fix PyTorch - TF tests for different GPU devices (#15846) 2022-02-28 15:46:46 -05:00
Nicolas Patry
97f9b8a27b Fixing the timestamps with chunking. (#15843)
* Fixing the timestamps with chunking.

* The changes modified (and fixed) the striding tests.

* Adding a tokenizer test.

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Defense -> comment.

* Update src/transformers/models/wav2vec2/tokenization_wav2vec2.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-02-28 21:00:21 +01:00
Sanchit Gandhi
e3342edc4e Flax Speech-Encoder-Decoder Model (#15613)
* rebase

* Delete shift tokens func

* downsample decoder input seq len for init

* correct attention mask

* add tests

* pt flax cross test

* make fixup

* init file for import

* change pt-flax cross test threshold

* pt-flax test logits only

* move tests

* make repo-consistency

* consistent indentation

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-02-28 12:22:36 +01:00
Patrick von Platen
935a76d90d [UniSpeechSat] correct unispeech sat (#15847) 2022-02-28 11:23:13 +01:00
Sayak Paul
84eaa6acf5 Add TFConvNextModel (#15750)
* feat: initial implementation of convnext in tensorflow.

* fix: sample code for the classification model.

* chore: added checked for  from the classification model.

* chore: set bias initializer in the classification head.

* chore: updated license terms.

* chore: removed ununsed imports

* feat: enabled  argument during using drop_path.

* chore: replaced tf.identity with layers.Activation(linear).

* chore: edited default checkpoint.

* fix: minor bugs in the initializations.

* partial-fix: tf model errors for loading pretrained pt weights.

* partial-fix: call method updated

* partial-fix: cross loading of weights (4x3 variables to be matched)

* chore: removed unneeded comment.

* removed playground.py

* rebasing

* rebasing and removing playground.py.

* fix: renaming TFConvNextStage conv and layer norm layers

* chore: added initializers and other minor additions.

* chore: added initializers and other minor additions.

* add: tests for convnext.

* fix: integration tester class.

* fix: issues mentioned in pr feedback (round 1).

* fix: how output_hidden_states arg is propoagated inside the network.

* feat: handling of  arg for pure cnn models.

* chore: added a note on equal contribution in model docs.

* rebasing

* rebasing and removing playground.py.

* feat: encapsulation for the convnext trunk.

* Fix variable naming; Test-related corrections; Run make fixup

* chore: added Joao as a contributor to convnext.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* chore: corrected copyright year and added comment on NHWC.

* chore: fixed the black version and ran formatting.

* chore: ran make style.

* chore: removed from_pt argument from test, ran make style.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* fix: tests in the convnext subclass, ran make style.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* chore: moved convnext test to the correct location

* fix: locations for the test file of convnext.

* fix: convnext tests.

* chore: applied  sgugger's suggestion for dealing w/ output_attentions.

* chore: added comments.

* chore: applied updated quality enviornment style.

* chore: applied formatting with quality enviornment.

* chore: revert to the previous tests/test_modeling_common.py.

* chore: revert to the original test_modeling_common.py

* chore: revert to previous states for test_modeling_tf_common.py and modeling_tf_utils.py

* fix: tests for convnext.

* chore: removed output_attentions argument from convnext config.

* chore: revert to the earlier tf utils.

* fix: output shapes of the hidden states

* chore: removed unnecessary comment

* chore: reverting to the right test_modeling_tf_common.py.

* Styling nits

Co-authored-by: ariG23498 <aritra.born2fly@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
2022-02-25 18:19:16 +01:00
Yih-Dar
8635407bc7 Fix tf.concatenate + test past_key_values for TF models (#15774)
* fix wrong method name tf.concatenate

* add tests related to causal LM / decoder

* make style and quality

* clean-up

* Fix TFBertModel's extended_attention_mask when past_key_values is provided

* Fix tests

* fix copies

* More tf.int8 -> tf.int32 in TF test template

* clean-up

* Update TF test template

* revert the previous commit + update the TF test template

* Fix TF template extended_attention_mask when past_key_values is provided

* Fix some styles manually

* clean-up

* Fix ValueError: too many values to unpack in the test

* Fix more: too many values to unpack in the test

* Add a comment for extended_attention_mask when there is past_key_values

* Fix TFElectra extended_attention_mask when past_key_values is provided

* Add tests to other TF models

* Fix for TF Electra test: add prepare_config_and_inputs_for_decoder

* Fix not passing training arg to lm_head in TFRobertaForCausalLM

* Fix tests (with past) for TF Roberta

* add testing for pask_key_values for TFElectra model

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2022-02-25 17:11:46 +01:00
Nicolas Patry
ad0d7d1745 Adding the option to return_timestamps on pure CTC ASR models. (#15792)
* Adding the option to return_timestamps on pure CTC ASR models.

* Remove `math.prod` which was introduced in Python 3.8

* int are not floats.

* Reworking the PR to support "char" vs "word" output.

* Fixup!

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update src/transformers/pipelines/automatic_speech_recognition.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Quality.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-02-25 14:06:45 +01:00
Sylvain Gugger
074645e32a Fix semantic segmentation pipeline test (#15826) 2022-02-25 09:21:29 +01:00
Patrick von Platen
cbf4391177 [TFXLNet] Correct tf xlnet generate (#15822)
* [TFXLNet] Correct tf xlnet

* adapt test comment
2022-02-24 19:23:34 +01:00
Patrick von Platen
ca57b45071 [Unispeech] Fix slow tests (#15818)
* remove soundfile old way of loading audio

* Adapt slow test
2022-02-24 19:08:54 +01:00
Sylvain Gugger
35ecf99cc4 Revert changes in logit size for semantic segmentation models (#15722)
* Revert changes in logit size for semantic segmentation models

* Address review comments
2022-02-24 15:52:52 +01:00
Sylvain Gugger
d1fcc90abf Fix from_pretrained with default base_model_prefix (#15814) 2022-02-24 11:43:51 +01:00
Lysandre Debut
29c10a41d0 [Test refactor 1/5] Per-folder tests reorganization (#15725)
* Per-folder tests reorganization

Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Stas Bekman <stas@stason.org>
2022-02-23 15:46:28 -05:00
Nicolas Patry
9e71d46455 Enable image-segmentation on AutoModelForSemanticSegmentation (#15647)
* Enabling Beit SegFormer to `image-segmentation`.

* Fixing the score.

* Fix import ?

* Missing in type hint.

* Multiple test fixes:

- Add `raw_image` support. It should be the default IMHO since in Python
  world it doesn't make any sense to base64 encode the image (Sorry
  @mishig, didn't catch that in my review). I really think we should
  consider breaking BC here.
- Add support for Segformer tiny test (needed
  `SegformerModelTester.get_config` to enable TinyConfig
  @NielsRogge)
- Add the check that `batch_size` works correctly on that pipeline.
  Uncovered that it doesn't for Detr, which IMO is OK since images
  after `feature_extractor` don't have the same size. Comment should
  explain.

* Type hint as a string.

* Make fixup + update black.

* torch+vision protections.

* Don't use torchvision, use F.interpolate instead (no new dep).

* Last fixes for Segformer.

* Update test to reflect new image (which was broken)

* Update tests.

* Major BC modification:

- Removed the string compressed PNG string, that's a job for users
`transformers` stays in python land.
- Removed the `score` for semantic segmentation. It has hardly a meaning
  on its own in this context.
- Don't include the grayscale with logits for now (which could enable
  users to get a sense of confidence). Might be done later.
- Don't include the surface of the mask (could be used for sorting by
  users, to filter out small masks). It's already calculable, and
  it's easier to add later, than to add now and break later if we need.

* `make fixup`.

* Small changes.

* Rebase + doc fixup.
2022-02-23 17:20:26 +01:00
Nicolas Patry
f9582c205a Adding ZeroShotImageClassificationPipeline (#12119)
* [Proposal] Adding ZeroShotImageClassificationPipeline

- Based on CLIP

* WIP, Resurection in progress.

* Resurrection... achieved.

* Reword handling different `padding_value` for `feature_extractor` and
`tokenizer`.

* Thanks doc-builder !

* Adding docs + global namespace `ZeroShotImageClassificationPipeline`.

* Fixing templates.

* Make the test pass and be robust to floating error.

* Adressing suraj's comments on docs mostly.

* Tf support start.

* TF support.

* Update src/transformers/pipelines/zero_shot_image_classification.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

Co-authored-by: Suraj Patil <surajp815@gmail.com>
2022-02-23 09:41:42 +01:00
Patrick von Platen
c44d3675c2 Time stamps for CTC models (#15687)
* [Wav2Vec2 Time Stamps]

* Add first version

* add word time stamps

* Fix

* save intermediate space

* improve

* [Finish CTC Tokenizer]

* remove @

* remove @

* push

* continue with phonemes

* up

* finish PR

* up

* add example

* rename

* finish

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* correct split

* finalize

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2022-02-22 19:26:44 +01:00
Funtowicz Morgan
32295b15a1 Gelu10 (#15676)
* Add GeLU10 (clipped version of GeLU) to transformers to improve quantization performances.

* Add unittests.

* Import tensorflow after `is_tf_available` check.

* Fix tensorflow wrong function `tf.tensor` to `tf.constant`

* style.

* use `tf.math.max`

* Fix tf tests.

* style.

* style style style style style style

* style style style style style style

* Address @sgugger comments.

* Fix wrong operator for raising ValueError for ClippedGELUActivation.
2022-02-22 18:21:16 +01:00
SaulLu
0187c6f0ad revert temporary addition to test next version of CLIPTokenizerFast (#15717) 2022-02-21 18:30:11 +01:00
Sanchit Gandhi
60ba48205e fix bug in PT speech-encoder-decoder (#15699)
* fix bug in PT speech-encoder-decoder

* add pt test for `inputs is not None`

* fix test

* new pt test

* Update tests/test_modeling_speech_encoder_decoder.py

* make fixup

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-02-18 18:20:24 +01:00
Lysandre Debut
83f45cd656 Fix auto (#15706) 2022-02-18 08:50:23 -05:00
Gunjan Chhablani
ae1f835028 Add PLBart (#13269)
* Init PLBART

* Add missing configuration file

* Add conversion script and configurationf ile

* Fix style

* Update modeling and conversion scripts

* Fix scale embedding in config

* Add comment

* Fix conversion script

* Add classification option to conversion script

* Fix vocab size in config doc

* Add tokenizer files from MBart50

* Allow no lang code in regular tokenizer

* Add PLBart Tokenizer Converters

* Remove mask from multi tokenizer

* Remove mask from multi tokenizer

* Change from MBart-50 to MBart tokenizer

* Fix names and modify src/tgt behavior

* Fix imports for tokenizer

* Remove <mask> from multi tokenizer

* Fix style

* Change tokenizer_class to processor_class

* Add attribute map to config class

* Update modeling file to modified MBart code

* Update configuration file to MBart style configuration

* Fix tokenizer

* Separate tokenizers

* Fix error in tokenization auto

* Copy MBart tests

* Replace with MBart tokenization tests

* Fix style

* Fix language code in multi tokenizer

* Fix configuration docs

* Add entry for plbart_multi in transformers init

* Add dummy objects and fix imports

* Fix modeling tests

* Add TODO in config

* Fix copyright year

* Fix modeling docs and test

* Fix some tokenization tests and style

* Add changes from review

* Fix copies

* Fix docs

* Fix docs

* Fix style

* Fix year

* Add changes from review

* Remove extra changes

* Fix base tokenizer and doc

* Fix style

* Fix modeling and slow tokenizer tests

* Remove Multi-tokenizer Converter and Tests

* Delete QA model and Multi Tokenizer dummy objects

* Fix repo consistency and code quality issues

* Fix example documentation

* Fix style

* Remove PLBartTokenizer from type checking in init

* Fix consistency issue

* Add changes from review

* Fix style

* Remove PLBartTokenizerFast

* Remove FastTokenizer converter

* Fix AutoTokenzier mapping

* Add plbart to toctree and fix consistency issues

* Add language codes tokenizer test

* Fix styling and doc issues

* Add fixes for failing tests

* Fix copies

* Fix failing modeling test

* Change assert to assertTrue in modeling tests
2022-02-18 14:17:09 +01:00
Yih-Dar
2f2fefd6af Fix LongformerModel hidden states (#15537)
* add undo padding

* fix

* fix tuple issue

* make style and quality

* move unpad logic to LongformerEncoder + unpad attentions + update tests

* move unpad logic to TFLongformerEncoder

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2022-02-18 13:56:53 +01:00
Yih-Dar
f8ff3fad87 TF: add initializer_std with a small value in TFFunnelModelTester (#15684) 2022-02-18 11:20:07 +00:00
SaulLu
e93763d420 fix CLIP fast tokenizer and change some properties of the slow version (#15067)
Very big changes concerning the tokenizer fast of CLIP which did not correspond to the tokenizer slow of CLIP

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2022-02-18 10:21:30 +01:00
NielsRogge
57882177be Add SimMIM (#15586)
* Add first draft

* Make model importable

* Make SwinForMaskedImageModeling importable

* Fix imports

* Add missing inits

* Add support for Swin

* Fix bug

* Fix bug

* Fix another bug

* Fix Swin MIM implementation

* Fix default encoder stride

* Fix Swin

* Add print statements for debugging

* Add image_size data argument

* Fix Swin

* Fix image_size

* Add print statements for debugging

* Fix print statement

* Remove print statements

* Improve reshaping of bool_masked_pos

* Add support for DeiT, fix tests

* Improve docstrings

* Apply new black version

* Improve script

* Fix bug

* Improve README

* Apply suggestions from code review

* Remove DS_Store and add to gitignore

* Apply suggestions from code review + fix BEiT Flax

* Revert BEiT changes

* Improve README

* Fix code quality

* Improve README

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
2022-02-17 19:44:55 +01:00
Tanay Mehta
f84e0dbd2a Add PoolFormer (#15531)
* Added all files, PoolFormerFeatureExtractor still failing tests

* Fixed PoolFormerFeatureExtractor not being able to import

* Completed Poolformer doc

* Applied Suggested fixes

* Fixed errors in modeling_auto.py

* Fix feature extractor, convert docs to Markdown, styling of code

* Remove PoolFormer from check_repo and fix integration test

* Remove Poolformer from check_repo

* Fixed configuration_poolformer.py docs and removed inference.py from poolformer

* Ran with black v22

* Added PoolFormer to _toctree.yml

* Updated poolformer doc

* Applied suggested fixes and added on README.md

* Did make fixup and make fix-copies, tests should pass now

* Changed PoolFormer weights conversion script name and fixed README

* Applied fixes in test_modeling_poolformer.py and modeling_poolformer.py

* Added PoolFormerFeatureExtractor to AutoFeatureExtractor API

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
2022-02-17 13:16:37 +01:00
Eldar Kurtic
f65fe3663a Implementation of activations as pytorch modules (#15616)
* Implement activations as pytorch modules

* Apply fixup

* Add missing tests for activations

* Update docstring

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-02-16 14:37:52 -05:00
Patrick von Platen
3a4376d008 [Wav2Vec2ProcessorWithLM] Fix auto processor with lm (#15683) 2022-02-16 17:33:33 +01:00
Sylvain Gugger
cdc51ffd27 Add register method to AutoProcessor (#15669)
* Add push_to_hub method to processors

* Fix test

* The other one too!

* Add register method to AutoProcessor

* Update src/transformers/models/auto/processing_auto.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>
2022-02-16 09:13:33 -05:00
Sylvain Gugger
2d02f7b29b Add push_to_hub method to processors (#15668)
* Add push_to_hub method to processors

* Fix test

* The other one too!
2022-02-15 21:14:04 -05:00
Lysandre Debut
1ddf3c2b74 Fix vit test (#15671) 2022-02-15 18:55:38 -05:00
Lysandre Debut
943e2aa036 Fix model equivalence tests (#15670)
* Fix model equivalence tests

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2022-02-15 18:55:22 -05:00
Patrick von Platen
2e12b907ae TF generate refactor - Greedy Search (#15562)
* TF generate start refactor

* Add tf tests for sample generate

* re-organize

* boom boom

* Apply suggestions from code review

* re-add

* add all code

* make random greedy pass

* make encoder-decoder random work

* further improvements

* delete bogus file

* make gpt2 and t5 tests work

* finish logits tests

* correct logits processors

* correct past / encoder_outputs drama

* refactor some methods

* another fix

* refactor shape_list

* fix more shape list

* import shape
_list

* finish docs

* fix imports

* make style

* correct tf utils

* Fix TFRag as well

* Apply Lysandre's and Sylvais suggestions

* Update tests/test_generation_tf_logits_process.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/tf_utils.py

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* remove cpu according to gante

* correct logit processor

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
2022-02-15 17:54:43 +01:00
Nicolas Patry
a3dbbc3467 Add decoder_kwargs to send to LM on asr pipeline. (#15646)
Co-authored-by: Giuseppe Attanasio <giuseppeattanasio6@gmail.com>

Co-authored-by: Giuseppe Attanasio <giuseppeattanasio6@gmail.com>
2022-02-15 17:53:24 +01:00
arampacha
67047b86ce add scores to Wav2Vec2WithLMOutput (#15413)
* add scores to Wav2Vec2WithLMOutput

* style fixup
2022-02-15 16:40:50 +01:00
Sylvain Gugger
45f56580a7 Allow custom code for Processors (#15649)
* Allow custom code for Processors

* Add more test

* Test all auto_map configs are properly set
2022-02-15 09:44:35 -05:00
Javier de la Rosa
9eb7e9ba1d Fix ASR pipelines from local directories with wav2vec models that have language models attached (#15590)
* Fix loading pipelines with wav2vec models with lm when in local paths

* Adding tests

* Fix test

* Adding tests

* Flake8 fixes

* Removing conflict files :(

* Adding task type to test

* Remove unnecessary test and imports
2022-02-15 13:45:08 +01:00
Patrick von Platen
041fdc4a7e [SpeechEncoderDecoder] Make sure no EOS is generated in test (#15655) 2022-02-15 09:13:55 +01:00
Sylvain Gugger
2e11a04337 Register feature extractor (#15634)
* Rework AutoFeatureExtractor.from_pretrained internal

* Custom feature extractor

* Add more tests

* Add support for custom feature extractor code

* Clean up

* Add register API to AutoFeatureExtractor
2022-02-14 13:35:16 -05:00
Sylvain Gugger
52d2e6f6e9 Add push to hub to feature extractor (#15632)
* Add push to hub to feature extractor

* Quality

* Clean up
2022-02-11 17:14:01 -05:00
Sylvain Gugger
7a32e4722f Custom feature extractor (#15630)
* Rework AutoFeatureExtractor.from_pretrained internal

* Custom feature extractor

* Add more tests

* Add support for custom feature extractor code

* Clean up
2022-02-11 16:43:54 -05:00
Sylvain Gugger
2dce350b33 Fix _configuration_file argument getting passed to model (#15629) 2022-02-11 13:46:08 -05:00
Joao Gante
2f40c728c9 TF MT5 embeddings resize (#15567)
* Fix TF MT5 vocab resize

* more assertive testing
2022-02-11 17:35:10 +00:00
Yih-Dar
724e51c6e6 Compute loss independent from decoder for TF EncDec models (as #14139) (#15175)
* Compute loss independent from decoder (as 14139)

* fix expected seq_len + style

* Apply the same change to TFVisionEncoderDecoderModel

* fix style

* Add case with labels in equivalence test

* uncomment

* Add case with labels in equivalence test

* add decoder_token_labels

* use hf_compute_loss

* Apply suggestions from code review

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Add copied from

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
2022-02-10 15:47:02 +01:00
Alberto Bégué
cb7ed6e083 Add Tensorflow handling of ONNX conversion (#13831)
* Add TensorFlow support for ONNX export

* Change documentation to mention conversion with Tensorflow

* Refactor export into export_pytorch and export_tensorflow

* Check model's type instead of framework installation to choose between TF and Pytorch

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Alberto Bégué <alberto.begue@della.ai>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2022-02-10 11:18:41 +01:00
Lysandre
e923917cd9 Reformat tokenization_fnet 2022-02-09 22:23:32 -05:00
Sylvain Gugger
644ec05233 Make slow tests slow 2022-02-09 19:10:22 -05:00
Sylvain Gugger
315e67404d Fix tests hub failure (#15580)
* Expose hub test problem

* Fix tests
2022-02-09 12:27:59 -05:00
Sylvain Gugger
b1ba03e082 Fix quality 2022-02-09 12:06:59 -05:00
Chan Woo Kim
2b5603f6ac Constrained Beam Search [without disjunctive decoding] (#15416)
* added classes to get started with constrained beam search

* in progress, think i can directly force tokens now but not yet with the round robin

* think now i have total control, now need to code the bank selection

* technically works as desired, need to optimize and fix design choices leading to undersirable outputs

* complete PR #1 without disjunctive decoding

* removed incorrect tests

* Delete k.txt

* Delete test.py

* Delete test.sh

* revert changes to test scripts

* genutils

* full implementation with testing, no disjunctive yet

* shifted docs

* passing all tests realistically ran locally

* removing accidentally included print statements

* fixed source of error in initial PR test

* fixing the get_device() vs device trap

* fixed documentation docstrings about constrained_beam_search

* fixed tests having failing for Speech2TextModel's floating point inputs

* fix cuda long tensor

* added examples and testing for them and founx & fixed a bug in beam_search and constrained_beam_search

* deleted accidentally added test halting code with assert False

* code reformat

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/test_generation_utils.py

* fixing based on comments on PR

* took out the testing code that should but work fails without the beam search moditification ; style changes

* fixing comments issues

* docstrings for ConstraintListState

* typo in PhrsalConstraint docstring

* docstrings improvements

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2022-02-09 16:59:26 +01:00