Commit Graph

18 Commits

Author SHA1 Message Date
Patrick von Platen
a1bbcf3f6c Refactoring the generate() function (#6949)
* first draft

* show design proposition for new generate method

* up

* make better readable

* make first version

* gpt2 tests pass

* make beam search for gpt2 work

* add first encoder-decoder code

* delete typo

* make t5 work

* save indermediate

* make bart work with beam search

* finish beam search bart / t5

* add default kwargs

* make more tests pass

* fix no bad words sampler

* some fixes and tests for all distribution processors

* fix test

* fix rag slow tests

* merge to master

* add nograd to generate

* make all slow tests pass

* speed up generate

* fix edge case bug

* small fix

* correct typo

* add type hints and docstrings

* fix typos in tests

* add beam search tests

* add tests for beam scorer

* fix test rag

* finish beam search tests

* move generation tests in seperate file

* fix generation tests

* more tests

* add aggressive generation tests

* fix tests

* add gpt2 sample test

* add more docstring

* add more docs

* finish doc strings

* apply some more of sylvains and sams comments

* fix some typos

* make fix copies

* apply lysandres and sylvains comments

* final corrections on examples

* small fix for reformer
2020-11-03 16:04:22 +01:00
Santiago Castro
969859d5f6 Fix doc errors and typos across the board (#8139)
* Fix doc errors and typos across the board

* Fix a typo

* Fix the CI

* Fix more typos

* Fix CI

* More fixes

* Fix CI

* More fixes

* More fixes
2020-10-29 10:33:33 -04:00
Sylvain Gugger
08f534d2da Doc styling (#8067)
* Important files

* Styling them all

* Revert "Styling them all"

This reverts commit 7d029395fdae8513b8281cbc2a6c239f8093503e.

* Syling them for realsies

* Fix syntax error

* Fix benchmark_utils

* More fixes

* Fix modeling auto and script

* Remove new line

* Fixes

* More fixes

* Fix more files

* Style

* Add FSMT

* More fixes

* More fixes

* More fixes

* More fixes

* Fixes

* More fixes

* More fixes

* Last fixes

* Make sphinx happy
2020-10-26 18:26:02 -04:00
ayushtiku5
776e82d2be Add support to provide initial tokens to decoder of encoder-decoder type models (#7577)
* Add support to provide initial tokens for decoding

* Add docstring

* improve code quality

* code reformat

* code reformat

* minor change

* remove appending decoder start token

Co-authored-by: Ayush Jain <a.jain@sprinklr.com>
2020-10-19 08:56:08 +02:00
Patrick von Platen
7fd1febf38 Add "Leveraging Pretrained Checkpoints for Generation" Seq2Seq models. (#6594)
* add conversion script

* improve conversion script

* make style

* add tryout files

* fix

* update

* add causal bert

* better names

* add tokenizer file as well

* finish causal_bert

* fix small bugs

* improve generate

* change naming

* renaming

* renaming

* renaming

* remove leftover files

* clean files

* add fix tokenizer

* finalize

* correct slow test

* update docs

* small fixes

* fix link

* adapt check repo

* apply sams and sylvains recommendations

* fix import

* implement Lysandres recommendations

* fix logger warn
2020-09-10 16:40:51 +02:00
Stas Bekman
03e363f9ae [generation] consistently add eos tokens (#6982)
Currently beam search returns inconsistent outputs - if hypos have different lengths we get eos, if they are the same - we don't.

This PR makes the output consistent.

Also why not also replace:

```
            if sent_lengths[i] < max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
```
with:
```
            decoded[i, sent_lengths[i]] = eos_token_id
```
Shouldn't eos always be there? If the data gets truncated, the caller needs to user a larger `max_length`.

Please correct me if my logic is flawed.
2020-09-09 04:08:36 -04:00
Stas Bekman
848fbe1e35 [gen utils] missing else case (#6980)
* [gen utils] missing else case

1. `else` is missing - I hit that case while porting a model. Probably needs to assert there?
2. also the comment on top seems to be outdated (just vocab_size is being set there)

* typo
2020-09-07 07:28:06 -04:00
Stas Bekman
c3317e1f80 typo (#6959)
there is no var `decoder_input_ids`, but there is `input_ids` for decoder :)
2020-09-07 05:16:24 -04:00
Patrick von Platen
afc4ece462 [Generate] Facilitate PyTorch generate using ModelOutputs (#6735)
* fix generate for GPT2 Double Head

* fix gpt2 double head model

* fix  bart / t5

* also add for no beam search

* fix no beam search

* fix encoder decoder

* simplify t5

* simplify t5

* fix t5 tests

* fix BART

* fix transfo-xl

* fix conflict

* integrating sylvains and sams comments

* fix tf past_decoder_key_values

* fix enc dec test
2020-09-01 12:38:25 +02:00
Lysandre
a75c64d80c Black 20 release 2020-08-26 17:20:22 +02:00
Lysandre Debut
77abd1e79f Centralize logging (#6434)
* Logging

* Style

* hf_logging > utils.logging

* Address @thomwolf's comments

* Update test

* Update src/transformers/benchmark/benchmark_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Revert bad change

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2020-08-26 11:10:36 -04:00
Oren Amsalem
93c5c9a528 [cleanup] remove confusing newline (#6603) 2020-08-20 00:33:36 -04:00
Sylvain Gugger
895ed8f451 Generation doc (#6470)
* Generation doc

* MBartForConditionalGeneration (#6441)

* add MBartForConditionalGeneration

* style

* rebase and fixes

* add mbart test in TEST_FILES_WITH_NO_COMMON_TESTS

* fix docs

* don't ignore mbart

* doc

* fix mbart fairseq link

* put mbart before bart

* apply doc suggestions

* Use hash to clean the test dirs (#6475)

* Use hash to clean the test dirs

* Use hash to clean the test dirs

* Use hash to clean the test dirs

* fix

* [EncoderDecoder] Add Cross Attention for GPT2 (#6415)

* add cross attention layers for gpt2

* make gpt2 cross attention work

* finish bert2gpt2

* add explicit comments

* remove attention mask since not yet supported

* revert attn mask in pipeline

* Update src/transformers/modeling_gpt2.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_encoder_decoder.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Sort unique_no_split_tokens to make it deterministic (#6461)

* change unique_no_split_tokens's type to set

* use sorted list instead of set

* style

* Import accuracy_score (#6480)

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address comments

* Styling

* Generation doc

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Address comments

* Styling

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Kevin Canwen Xu <canwenxu@126.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Co-authored-by: gijswijnholds <gijswijnholds@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
2020-08-14 09:46:39 -04:00
Patrick von Platen
1d6e71e116 [EncoderDecoder] Add Cross Attention for GPT2 (#6415)
* add cross attention layers for gpt2

* make gpt2 cross attention work

* finish bert2gpt2

* add explicit comments

* remove attention mask since not yet supported

* revert attn mask in pipeline

* Update src/transformers/modeling_gpt2.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_encoder_decoder.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2020-08-14 09:43:29 +02:00
Zhu Baohe
9d94aecd51 Fix docs and bad word tokens generation_utils.py (#6387)
* fix

* fix2

* fix3
2020-08-13 13:12:16 +02:00
guillaume-be
404782912a [Performance improvement] "Bad tokens ids" optimization (#6064)
* Optimized banned token masking

* Avoid duplicate EOS masking if in bad_words_id

* Updated mask generation to handle empty banned token list

* Addition of unit tests for the updated bad_words_ids masking

* Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test

* Updated timeout handling in `test_postprocess_next_token_scores_large_bad_words_list` unit test (timeout does not work on Windows)

* Moving Marian import to the test context to allow TF only environments to run

* Moving imports to torch_available test

* Updated operations device and test

* Updated operations device and test

* Added docstring and comment for in-place scores modification

* Moving test to own test_generation_utils, use of lighter models for testing

* removed unneded imports in test_modeling_common

* revert formatting change for ModelTesterMixin

* Updated caching, simplified eos token id test, removed unnecessary @require_torch

* formatting compliance
2020-08-11 05:56:40 -04:00
Patrick von Platen
991172922f better error message (#5497) 2020-07-03 19:25:25 +02:00
Yacine Jernite
c4d4e8bdbd Move GenerationMixin to separate file (#5254)
* separate_generation_code

* isort

* renamed

* rename_files

* move_shapelit
2020-06-30 10:42:08 -04:00