Compare commits
167 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d9a9d0c72 | ||
|
|
c949516695 | ||
|
|
04dc65e5c6 | ||
|
|
b2dfcc567b | ||
|
|
eabad8fd9c | ||
|
|
0c9f01a8e5 | ||
|
|
27d0e01d75 | ||
|
|
245cdb469d | ||
|
|
247a7b2029 | ||
|
|
69ed36063a | ||
|
|
2df34f4aba | ||
|
|
5f6721032a | ||
|
|
063d8d27f4 | ||
|
|
e6ecef711e | ||
|
|
250f27f207 | ||
|
|
dfbf0f5598 | ||
|
|
a1100fac67 | ||
|
|
e45eba3b1c | ||
|
|
ccd1923f46 | ||
|
|
2aa9c2f204 | ||
|
|
406cbf58b2 | ||
|
|
3b67c5abb0 | ||
|
|
a051d8928a | ||
|
|
7f28613213 | ||
|
|
0ecbb69806 | ||
|
|
e6f211cade | ||
|
|
01a1684078 | ||
|
|
6009668c63 | ||
|
|
ba702966ba | ||
|
|
33b7422839 | ||
|
|
6f63501383 | ||
|
|
d20e9c7299 | ||
|
|
8d25df2c7a | ||
|
|
5a442a8db1 | ||
|
|
6c8ec2a931 | ||
|
|
1e3c362235 | ||
|
|
d415882b41 | ||
|
|
1243ee7d0c | ||
|
|
cf416764f4 | ||
|
|
09926c8e86 | ||
|
|
4f7022d68d | ||
|
|
96f1f74aaf | ||
|
|
1c19b423bf | ||
|
|
02e05fb0a5 | ||
|
|
4fbcf8ea49 | ||
|
|
e34e45536f | ||
|
|
1bdf42409c | ||
|
|
79bbcc5260 | ||
|
|
9e1ea846bc | ||
|
|
bf9056442a | ||
|
|
f33a6f3446 | ||
|
|
758ed3332b | ||
|
|
a400fe8931 | ||
|
|
ae5a32bb0d | ||
|
|
812045adcc | ||
|
|
390cf16bc8 | ||
|
|
28d74872cc | ||
|
|
3ec40299c1 | ||
|
|
b8462b5b2a | ||
|
|
0c96262f7d | ||
|
|
c89f1bc92e | ||
|
|
7a9f1b5c99 | ||
|
|
ecfcac223c | ||
|
|
be898998bb | ||
|
|
b972c1bfb0 | ||
|
|
bcb55d33ce | ||
|
|
b7e548976f | ||
|
|
9f675b05d4 | ||
|
|
453a70d4cb | ||
|
|
7988edc031 | ||
|
|
c9553c0352 | ||
|
|
090d28e32d | ||
|
|
d64372fdfc | ||
|
|
eef66035a2 | ||
|
|
4eec5d0cf6 | ||
|
|
d9e848c1d6 | ||
|
|
29acabd886 | ||
|
|
57a6626929 | ||
|
|
189387e9b2 | ||
|
|
314cca2842 | ||
|
|
52d62e686c | ||
|
|
748006c0b3 | ||
|
|
4225740a7b | ||
|
|
4aa8f6ad99 | ||
|
|
83eec97ec6 | ||
|
|
30fa0b780f | ||
|
|
143289dcf7 | ||
|
|
086718ac6e | ||
|
|
47ca0eaaac | ||
|
|
75ff530551 | ||
|
|
ec54d70e16 | ||
|
|
de29ff9bd2 | ||
|
|
d018afced0 | ||
|
|
d735b074d7 | ||
|
|
5dd389d1c7 | ||
|
|
23a71449c0 | ||
|
|
6c03d4ac70 | ||
|
|
c581d8af7a | ||
|
|
8eb7f26d5d | ||
|
|
d944966b19 | ||
|
|
c4fd609afb | ||
|
|
b01f451ca3 | ||
|
|
5f7a07c0c8 | ||
|
|
ae333d04b2 | ||
|
|
8217d4e37f | ||
|
|
912f6881d2 | ||
|
|
785e52cd30 | ||
|
|
64103fb6be | ||
|
|
d97d06d05f | ||
|
|
83fdd252f6 | ||
|
|
8e74eca7f2 | ||
|
|
61443cd7d9 | ||
|
|
21fc676645 | ||
|
|
52b3a05e83 | ||
|
|
7777db159f | ||
|
|
71963a6633 | ||
|
|
f3a3b91d6f | ||
|
|
2a18b70998 | ||
|
|
6189ae9960 | ||
|
|
222dbdb203 | ||
|
|
6c091abef2 | ||
|
|
88ef8893cd | ||
|
|
a1cb6e9866 | ||
|
|
bcc87c639f | ||
|
|
d5db6c37d4 | ||
|
|
4bafc43b0e | ||
|
|
58e8a7611f | ||
|
|
cbe63949d7 | ||
|
|
e6c1f1cad8 | ||
|
|
ab17758874 | ||
|
|
5b5f7dd09c | ||
|
|
1558d191e6 | ||
|
|
37d6fb5d04 | ||
|
|
189c1b91a6 | ||
|
|
490b39e614 | ||
|
|
1fc7119181 | ||
|
|
e9d77ccd5a | ||
|
|
ec07da65e2 | ||
|
|
4eef5889ac | ||
|
|
9a12b9696f | ||
|
|
f4432b7e01 | ||
|
|
08abdabda1 | ||
|
|
161a6461db | ||
|
|
5a8a4eb187 | ||
|
|
6b034309ca | ||
|
|
a4b21cdd20 | ||
|
|
f38c4ad302 | ||
|
|
e0e255be1f | ||
|
|
6b850b671d | ||
|
|
3ff5e8955a | ||
|
|
291974c65c | ||
|
|
1198ba8fba | ||
|
|
9a25c5bd3a | ||
|
|
3e56e2ce04 | ||
|
|
077a5dce32 | ||
|
|
84d5879eaf | ||
|
|
fd7b6a5274 | ||
|
|
66a14a2f6f | ||
|
|
f06d0fadc9 | ||
|
|
467e9158b4 | ||
|
|
63841c559b | ||
|
|
bf713cdec7 | ||
|
|
bd40345d3e | ||
|
|
bfa4ccf77d | ||
|
|
e0790cca78 | ||
|
|
6d2e864db7 | ||
|
|
f83d9c8da7 |
@@ -53,4 +53,5 @@ deploy_doc "3ebb1b3" v3.2.0
|
||||
deploy_doc "0613f05" v3.3.1
|
||||
deploy_doc "eb0e0ce" v3.4.0
|
||||
deploy_doc "818878d" v3.5.1
|
||||
deploy_doc "c781171" # v4.0.0 Latest stable release
|
||||
deploy_doc "c781171" v4.0.0
|
||||
deploy_doc "bfa4ccf" # v4.1.1 Latest stable release
|
||||
|
||||
9
.github/ISSUE_TEMPLATE/bug-report.md
vendored
9
.github/ISSUE_TEMPLATE/bug-report.md
vendored
@@ -11,7 +11,7 @@ assignees: ''
|
||||
## Environment info
|
||||
<!-- You can run the command `transformers-cli env` and copy-and-paste its output below.
|
||||
Don't forget to fill out the missing fields in that output! -->
|
||||
|
||||
|
||||
- `transformers` version:
|
||||
- Platform:
|
||||
- Python version:
|
||||
@@ -24,13 +24,13 @@ assignees: ''
|
||||
<!-- Your issue will be replied to more quickly if you can figure out the right person to tag with @
|
||||
If you know how to use git blame, that is the easiest way, otherwise, here is a rough guide of **who to tag**.
|
||||
Please tag fewer than 3 people.
|
||||
|
||||
albert, bert, GPT2, XLM: @LysandreJik
|
||||
|
||||
albert, bert, GPT2, XLM: @LysandreJik
|
||||
tokenizers: @mfuntowicz
|
||||
Trainer: @sgugger
|
||||
Speed and Memory Benchmarks: @patrickvonplaten
|
||||
Model Cards: @julien-c
|
||||
TextGeneration: @TevenLeScao
|
||||
TextGeneration: @TevenLeScao
|
||||
examples/distillation: @VictorSanh
|
||||
nlp datasets: [different repo](https://github.com/huggingface/nlp)
|
||||
rust tokenizers: [different repo](https://github.com/huggingface/tokenizers)
|
||||
@@ -47,6 +47,7 @@ assignees: ''
|
||||
FSMT: @stas00
|
||||
examples/seq2seq: @patil-suraj
|
||||
examples/bert-loses-patience: @JetRunner
|
||||
ray/raytune: @richardliaw @amogkam
|
||||
tensorflow: @jplu
|
||||
examples/token-classification: @stefan-it
|
||||
documentation: @sgugger
|
||||
|
||||
1
.github/stale.yml
vendored
1
.github/stale.yml
vendored
@@ -6,6 +6,7 @@ daysUntilClose: 7
|
||||
exemptLabels:
|
||||
- pinned
|
||||
- security
|
||||
- Feature request
|
||||
# Label to use when marking an issue as stale
|
||||
staleLabel: wontfix
|
||||
# Comment to post when marking an issue as stale. Set to `false` to disable
|
||||
|
||||
8
.github/workflows/github-torch-hub.yml
vendored
8
.github/workflows/github-torch-hub.yml
vendored
@@ -1,6 +1,6 @@
|
||||
name: Torch hub integration
|
||||
|
||||
on:
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "*"
|
||||
@@ -32,8 +32,10 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install torch
|
||||
pip install numpy filelock protobuf requests tqdm regex sentencepiece sacremoses tokenizers packaging
|
||||
# install torch-hub specific dependencies
|
||||
pip install -e git+https://github.com/huggingface/transformers.git#egg=transformers[torchhub]
|
||||
# no longer needed
|
||||
pip uninstall -y transformers
|
||||
|
||||
- name: Torch hub list
|
||||
run: |
|
||||
|
||||
2
.github/workflows/model-templates.yml
vendored
2
.github/workflows/model-templates.yml
vendored
@@ -40,6 +40,8 @@ jobs:
|
||||
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/pt-encoder-bert-tokenizer.json --path=templates/adding_a_new_model
|
||||
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/standalone.json --path=templates/adding_a_new_model
|
||||
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/tf-encoder-bert-tokenizer.json --path=templates/adding_a_new_model
|
||||
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/tf-seq-2-seq-bart-tokenizer.json --path=templates/adding_a_new_model
|
||||
transformers-cli add-new-model --testing --testing_file=templates/adding_a_new_model/tests/pt-seq-2-seq-bart-tokenizer.json --path=templates/adding_a_new_model
|
||||
make style
|
||||
python utils/check_table.py --fix_and_overwrite
|
||||
python utils/check_dummies.py --fix_and_overwrite
|
||||
|
||||
2
.github/workflows/self-scheduled.yml
vendored
2
.github/workflows/self-scheduled.yml
vendored
@@ -75,7 +75,7 @@ jobs:
|
||||
RUN_SLOW: yes
|
||||
run: |
|
||||
source .env/bin/activate
|
||||
pip install -r examples/requirements.txt
|
||||
pip install -r examples/_tests_requirements.txt
|
||||
python -m pytest -n 1 --dist=loadfile -s --make-reports=examples_torch_gpu examples
|
||||
|
||||
- name: Failure short reports
|
||||
|
||||
@@ -328,11 +328,18 @@ for more information.
|
||||
|
||||
### Develop on Windows
|
||||
|
||||
On windows, you need to configure git to transform Windows `CRLF` line endings to Linux `LF` line endings:
|
||||
|
||||
`git config core.autocrlf input`
|
||||
|
||||
One way one can run the make command on Window is to pass by MSYS2:
|
||||
|
||||
1. [Download MSYS2](https://www.msys2.org/), we assume to have it installed in C:\msys64
|
||||
2. Open the command line C:\msys64\msys2.exe (it should be available from the start menu)
|
||||
3. Run in the shell: `pacman -Syu` and install make with `pacman -S make`
|
||||
4. Add `C:\msys64\usr\bin` to your PATH environment variable.
|
||||
|
||||
You can now use `make` from any terminal (Powershell, cmd.exe, etc) 🎉
|
||||
|
||||
### Syncing forked master with upstream (HuggingFace) master
|
||||
|
||||
|
||||
275
ISSUES.md
Normal file
275
ISSUES.md
Normal file
@@ -0,0 +1,275 @@
|
||||
<!---
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# How To Request Support
|
||||
|
||||
This is an Open Source Project so please be mindful that like in any other project of this kind there is no obligation to answer all requests for help.
|
||||
|
||||
However, we want to encourage you to ask for help whenever you think it's needed! We are happy about every question we get because it allows us to better understand your needs, possible misunderstandings, and most importantly a way for you to help us make this library better. That being said, this document's main purpose is to provide guidelines at how you can formulate your requests to increase your chances to be understood and to get support.
|
||||
|
||||
There are two main venues to receive support: [the forums](https://discuss.huggingface.co/) and [the GitHub issues](https://github.com/huggingface/transformers/issues).
|
||||
|
||||
## The Forums
|
||||
|
||||
[The user forums](https://discuss.huggingface.co/) are supported by the wide community of the library users and backed up by developers when needed.
|
||||
|
||||
If you have a difficulty with deploying this library or some questions, or you'd like to discuss a new feature, please first consider discussing those things at the forums. Only when you feel your subject matter has been crystalized and you still need support from the library developers do proceed to file an [issue](https://github.com/huggingface/transformers/issues).
|
||||
|
||||
In particular all "Please explain" questions or objectively very user-specific feature requests belong to the forums. Here are some example of such questions:
|
||||
|
||||
* "I would like to use a BertModel within a RL-Agent for a customer support service. How can I use a BertForMaskedLM in my ChatBotModel?"
|
||||
|
||||
* "Could you please explain why T5 has no positional embedding matrix under T5Model?"
|
||||
|
||||
* "How should I set my generation parameters for translation?"
|
||||
|
||||
* "How to train T5 on De->En translation?"
|
||||
|
||||
|
||||
## The GitHub Issues
|
||||
|
||||
Everything which hints at a bug should be opened as an [issue](https://github.com/huggingface/transformers/issues).
|
||||
|
||||
You are not required to read the following guidelines before opening an issue. However, if you notice that your issue doesn't get any replies, chances are that the developers have one or several difficulties with its quality. In this case, reading the following points and adjusting your issue accordingly could help.
|
||||
|
||||
1. Before posting an issue, first search for already posted issues, since chances are someone has already asked a similar question before you.
|
||||
|
||||
If you use Google your search query should be:
|
||||
|
||||
```
|
||||
"huggingface" "transformers" your query
|
||||
```
|
||||
|
||||
The first two quoted words tell Google to limit the search to the context of the Huggingface Transformers. The remainder is your query - most commonly this would be the error message the software fails with. We will go deeper into details shortly.
|
||||
|
||||
The results of such a query will typically match GitHub issues, Hugging Face forums, StackExchange, and blogs.
|
||||
|
||||
If you find relevant hints, you may choose to continue the discussion there if you have follow up questions.
|
||||
|
||||
If what you found is similar but doesn't quite answer your problem, please, post a new issue and do include links to similar issues or forum discussions you may have found.
|
||||
|
||||
Let's look at some examples:
|
||||
|
||||
The error message, often referred to as an assertion, tells us what went wrong. Here is an example of an assertion:
|
||||
|
||||
```python
|
||||
Traceback (most recent call last):
|
||||
File "<string>", line 1, in <module>
|
||||
File "/transformers/src/transformers/__init__.py", line 34, in <module>
|
||||
from . import dependency_versions_check
|
||||
File "/transformers/src/transformers/dependency_versions_check.py", line 34, in <module>
|
||||
from .file_utils import is_tokenizers_available
|
||||
File "/transformers/src/transformers/file_utils.py", line 40, in <module>
|
||||
from tqdm.auto import tqdm
|
||||
ModuleNotFoundError: No module named 'tqdm.auto'
|
||||
```
|
||||
|
||||
and it typically includes a traceback, so that we can see the full stack of calls the program made before it fails. This gives us the context to know why the program failed.
|
||||
|
||||
Going back to the above example. If you received this error search, look at the very last line of the error which is:
|
||||
|
||||
```python
|
||||
ModuleNotFoundError: No module named 'tqdm.auto'
|
||||
```
|
||||
|
||||
And now we can use it to do the searching on your favorite search engine:
|
||||
|
||||
1. first for `"huggingface" "transformers" "ModuleNotFoundError: No module named 'tqdm.auto'"`
|
||||
2. if you don't find relevant results, then search for just `"ModuleNotFoundError: No module named 'tqdm.auto'"`
|
||||
3. and finally if nothing still comes up, then remove the outside quotes: `ModuleNotFoundError: No module named 'tqdm.auto'`
|
||||
|
||||
If the error includes any messages that include bits unique to your filesystem, always remove those in the search query since other users will not have the same filesystem as yours. For example:
|
||||
|
||||
```bash
|
||||
python -c 'open("/tmp/wrong_path.txt", "r")'
|
||||
Traceback (most recent call last):
|
||||
File "<string>", line 1, in <module>
|
||||
FileNotFoundError: [Errno 2] No such file or directory: '/tmp/wrong_path.txt'
|
||||
```
|
||||
Here you'd search for just: `"FileNotFoundError: [Errno 2] No such file or directory"`
|
||||
|
||||
If the local information that you removed were inside the error message and you removed them you may need to remove double quotes since your query is no longer exact. So if the error message was something like:
|
||||
|
||||
```bash
|
||||
ValueError: '/tmp/wrong_path.txt' cannot be found
|
||||
```
|
||||
|
||||
then you'd search for `"ValueError" "cannot be found"`
|
||||
|
||||
As you search you will notice that when you don't use quotes often the search engines will return a variety of unrelated hits, which may or may not be what you want.
|
||||
|
||||
Experiment with different ways and find which approach gives the most satisfactory results.
|
||||
|
||||
2. Keep the issue short, providing the information that you think will aid the developers to understand your situation. Put yourself in the shoes of the person who has never seen your code or knows anything about your custom setup. This mental exercise will help to develop an intuition to what/what not to share"
|
||||
|
||||
3. If there is a software failure, always provide the full traceback, for example:
|
||||
|
||||
```python
|
||||
$ python -c 'import transformers'
|
||||
Traceback (most recent call last):
|
||||
File "<string>", line 1, in <module>
|
||||
File "/transformers/src/transformers/__init__.py", line 34, in <module>
|
||||
from . import dependency_versions_check
|
||||
File "/transformers/src/transformers/dependency_versions_check.py", line 34, in <module>
|
||||
from .file_utils import is_tokenizers_available
|
||||
File "/transformers/src/transformers/file_utils.py", line 40, in <module>
|
||||
from tqdm.auto import tqdm
|
||||
ModuleNotFoundError: No module named 'tqdm.auto'
|
||||
```
|
||||
|
||||
As compared to providing just the last line of the error message, e.g.:
|
||||
```python
|
||||
ModuleNotFoundError: No module named 'tqdm.auto'
|
||||
```
|
||||
which is not sufficient.
|
||||
|
||||
If your application is running on more than one GPU (e.g. under `DistributedDataParallel`) and typically getting every log and traceback printed multiple times, please make sure that you paste only one copy of it. At times the traceback from parallel processes may get interleaved - so either disentangle these or change the loggers to log only for `local_rank==0` so that only one process logs things.
|
||||
|
||||
4. When quoting a traceback, command line instructions and any type of code always enclose it in triple backticks inside the editor window, that is:
|
||||
|
||||
````
|
||||
```
|
||||
git clone https://github.com/huggingface/transformers
|
||||
cd transformers
|
||||
pip install .
|
||||
```
|
||||
````
|
||||
|
||||
If it's a command line with a long argument list, please consider breaking it down using backslashes and new lines. Here is an example of a good command line quote:
|
||||
|
||||
```bash
|
||||
cd examples/seq2seq
|
||||
python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py \
|
||||
--model_name_or_path sshleifer/distill-mbart-en-ro-12-4 --data_dir wmt_en_ro \
|
||||
--output_dir output_dir --overwrite_output_dir \
|
||||
--do_train --n_train 500 --num_train_epochs 1 \
|
||||
--per_device_train_batch_size 1 --freeze_embeds \
|
||||
--src_lang en_XX --tgt_lang ro_RO --task translation \
|
||||
--fp16 --sharded_ddp
|
||||
```
|
||||
|
||||
If you don't break it up, one has to scroll horizontally which often makes it quite difficult to quickly see what's happening.
|
||||
|
||||
The backslashes allow us to copy the command directly into the console to run it, without needing to edit it.
|
||||
|
||||
5. Include only the important information that you think will help the developer to quickly identify the problem.
|
||||
|
||||
For example applications often create huge amounts of logs. Ask yourself whether providing all or parts of the log is useful.
|
||||
|
||||
Pasting a 100-1000 lines of log into the issue is an immediate turn off, since it will take a lot of time to figure out where the pertinent parts of the log are.
|
||||
|
||||
Attaching a full log can be helpful if it's done as an attachment, if it's enclosed in the following html code in the comment editor window:
|
||||
|
||||
```
|
||||
<details>
|
||||
<summary>Full log</summary>
|
||||
<pre>
|
||||
|
||||
many
|
||||
lines
|
||||
go
|
||||
here
|
||||
|
||||
</pre>
|
||||
</details>
|
||||
```
|
||||
|
||||
which would result in the following entry, which can be opened if desired, but otherwise takes little space.
|
||||
|
||||
<details>
|
||||
<summary>Full log</summary>
|
||||
<pre>
|
||||
many
|
||||
lines
|
||||
go
|
||||
here
|
||||
</pre>
|
||||
</details>
|
||||
|
||||
You could also provide a link to a pastebin service, but this is less beneficial since those links tend to expire quickly and future readers of your issue might not be able to access that log file anymore and may lack some context.
|
||||
|
||||
6. If this is an issue in your code, do try to reduce that code to a minimal example that still demonstrates the problem. Please ask at the forums if you have a hard time figuring how to do that. Please realize that we don't have the luxury of having time to try and understand all of your custom code.
|
||||
|
||||
If you really tried to make a short reproducible code but couldn't figure it out, it might be that having a traceback will give the developer enough information to know what's going on. But if it is not enough and we can't reproduce the problem, we can't really solve it.
|
||||
|
||||
Do not dispair if you can't figure it out from the begining, just share what you can and perhaps someone else will be able to help you at the forums.
|
||||
|
||||
7. If you forked off some of this project's code or example applications, please, do not ask us to go into your code repository and figure out what you may have done. The code is already very complex and unless there is an easy way to do a diff and it's a small diff, it won't be possible to find someone with time on their hands to make a lengthy investigation. Albeit, you might find someone at the forums who will be generous to do this for you.
|
||||
|
||||
8. Before reporting an issue, first, always try to update your environment to the latest official version of this library. We have no resources to go and debug older revisions, which could easily have bugs that have been fixed in the latest released version.
|
||||
|
||||
We understand that this is not always possible, especially when APIs change, in which case file an issue against the highest library version your environment can support.
|
||||
|
||||
Of course, if you upgrade the library, always retest that the problem is still there.
|
||||
|
||||
9. Please do not ask us to reproduce an issue with your custom data, since we don't have it. So, either you should use some existing dataset supported by HF datasets or you need to supply a code that generates a small sample on the fly, or some another quick and simple way to get it.
|
||||
|
||||
Please do not send us any non-public domain data that may require a license or a permission to be used.
|
||||
|
||||
10. Do not tag multiple developers on the issue unless you know this is expected, either because you asked them and they gave you an explicit permission to tag them or the issue template instructs you to do so.
|
||||
|
||||
The "who to tag for what domain" part of the issue template is there to help users direct their questions to the right developers who are designated maintainers of project's specific domains. They can then decide at their own discretion to tag other developers if they feel it'd help move the issue forward.
|
||||
|
||||
We currently don't have a triage service and we trust your capacity to identify the right domain and thus the persons to tag in your issue. If you are not sure, please use the forums to ask for guidance.
|
||||
|
||||
When in doubt, err on the side of not tagging a given person. If you tag multiple people out of context or permission don't be surprised if you get no response at all. Please remember that every time you tag someone, they get a notification and you're taking their time without their permission. Please be sensitive to that.
|
||||
|
||||
If you got helped by one of the developers in the past please don't tag them in future issues, unless they are listed in the issue template for the domain you are asking about or that developer gave you an explicit permission to tag them in future issues.
|
||||
|
||||
If you see a certain developer doing multiple and/or recent commits into a specific area of the project that you feel is relevant to your issue, it is not a good reason to tag them. Various developers may be fixing things that prevent them from moving forward, but often their work is focused on a totally different domain. And while they may or may not know how to help you with the problem at hand, it would benefit the whole community much more if they focus on the domain of their unique expertise.
|
||||
|
||||
11. Use the Edit button. Take your time, and re-read and improve the wording and formatting to make your posts and comments as easy to understand as possible.
|
||||
|
||||
Avoid posting multiple comments in a row, as each comment generates a notification for the developers tagged in that issue. If you happened to post multiple comments in a row, and nobody followed up yet - consider merging those into one or a few comments while editing the combined content to be coherent.
|
||||
|
||||
If you choose to edit your older comments after others posted follow up comments you need to be aware that your modifications might not be noticed, so if it's not a typo fixing, try to write a new comment flagging that something has been changed in the previous comments.
|
||||
|
||||
For example, the very first comment is the most important one. If while the thread unfolds you realize that things aren't as they seemed to you originally you may want to edit the first post to reflect the up-to-date understanding of the issue at hand so that it helps those who read your issue in the future quickly understand what's going on and not need to sift through dozens of comments. It also helps to indicate that the post was edited. So, those reading the thread later can understand why there might be certain discontinuity in the information flow.
|
||||
|
||||
Use bullets and items if you have lists of items and the outcome improves overall readability.
|
||||
|
||||
Use backticks to refer to class and function names, e.g. `BartModel` and `generate` as these stand out and improve the speed of a reader's comprehension.
|
||||
|
||||
Try not use italics and bold text too much as these often make the text more difficult to read.
|
||||
|
||||
|
||||
12. If you are cross-referencing a specific comment in a given thread or another issue, always link to that specific comment, rather than using the issue link. If you do the latter it could be quite impossible to find which specific comment you're referring to.
|
||||
|
||||
To get the link to the specific comment do not copy the url from the location bar of your browser, but instead, click the `...` icon in the upper right corner of the comment and then select "Copy Link".
|
||||
|
||||
For example the first link is a link to an issue, and the second to a specific comment in the same issue:
|
||||
|
||||
1. https://github.com/huggingface/transformers/issues/9257
|
||||
2. https://github.com/huggingface/transformers/issues/9257#issuecomment-749945162
|
||||
|
||||
|
||||
13. If you are replying to a last comment, it's totally fine to make your reply with just your comment in it. The readers can follow the information flow here.
|
||||
|
||||
But if you're replying to a comment that happened some comments back it's always a good practice to quote just the relevant lines you're replying it. The `>` is used for quoting, or you can always use the menu to do so. For example your editor box will look like:
|
||||
|
||||
```
|
||||
> How big is your gpu cluster?
|
||||
|
||||
Our cluster is made of 256 gpus.
|
||||
```
|
||||
|
||||
If you are addressing multiple comments, quote the relevant parts of each before your answer. Some people use the same comment to do multiple replies, others separate them into separate comments. Either way works. The latter approach helps for linking to a specific comment.
|
||||
|
||||
In general the best way to figure out what works the best is learn from issues posted by other people - see which issues get great responses and which get little to no response - observe what the posters who received great responses did differently from those who did not.
|
||||
|
||||
Thank you for reading this somewhat lengthy document. We would like to conclude that these are not absolute rules, but a friendly advice that will help maximize the chances for us to understand what you are trying to communicate, reproduce the problem then resolve it to your satisfaction and the benefit of the whole community.
|
||||
|
||||
If after reading this document there are remaining questions on how and why or there is a need for further elucidation, please, don't hesitate to ask your question in [this thread](https://discuss.huggingface.co/t/how-to-request-support/3128).
|
||||
2
Makefile
2
Makefile
@@ -67,4 +67,4 @@ test-examples:
|
||||
# Check that docs can build
|
||||
|
||||
docs:
|
||||
cd docs && make html SPHINXOPTS="-W"
|
||||
cd docs && make html SPHINXOPTS="-W -j 4"
|
||||
|
||||
@@ -49,7 +49,7 @@ limitations under the License.
|
||||
|
||||
## Online demos
|
||||
|
||||
You can test most of our models directly on their pages from the [model hub](https://huggingface.co/models). We also offer an [inference API](https://huggingface.co/pricing) to use those models.
|
||||
You can test most of our models directly on their pages from the [model hub](https://huggingface.co/models). We also offer [private model hosting, versioning, & an inference API](https://huggingface.co/pricing) to use those models.
|
||||
|
||||
Here are a few examples:
|
||||
- [Masked word completion with BERT](https://huggingface.co/bert-base-uncased?text=Paris+is+the+%5BMASK%5D+of+France)
|
||||
@@ -195,6 +195,7 @@ Current number of checkpoints: ** (from Google) released with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.
|
||||
1. **[BERT For Sequence Generation](https://huggingface.co/transformers/model_doc/bertgeneration.html)** (from Google) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
|
||||
1. **[Blenderbot](https://huggingface.co/transformers/model_doc/blenderbot.html)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
|
||||
1. **[BlenderbotSmall](https://huggingface.co/transformers/model_doc/blenderbot_small.html)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
|
||||
1. **[CamemBERT](https://huggingface.co/transformers/model_doc/camembert.html)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
|
||||
1. **[CTRL](https://huggingface.co/transformers/model_doc/ctrl.html)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
1. **[DeBERTa](https://huggingface.co/transformers/model_doc/deberta.html)** (from Microsoft Research) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen.
|
||||
@@ -209,6 +210,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
1. **[GPT](https://huggingface.co/transformers/model_doc/gpt.html)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
|
||||
1. **[GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
|
||||
1. **[LayoutLM](https://huggingface.co/transformers/model_doc/layoutlm.html)** (from Microsoft Research Asia) released with the paper [LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
|
||||
1. **[LED](https://huggingface.co/transformers/model_doc/led.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[Longformer](https://huggingface.co/transformers/model_doc/longformer.html)** (from AllenAI) released with the paper [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
1. **[LXMERT](https://huggingface.co/transformers/model_doc/lxmert.html)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
|
||||
1. **[MarianMT](https://huggingface.co/transformers/model_doc/marian.html)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
|
||||
@@ -222,7 +224,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
ultilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation) and a German version of DistilBERT.
|
||||
1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
|
||||
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
1. **[TAPAS](https://huggingface.co/transformers/master/model_doc/tapas.html)** released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
|
||||
1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
// These two things need to be updated at each release for the version selector.
|
||||
// Last stable version
|
||||
const stableVersion = "v4.0.0"
|
||||
const stableVersion = "v4.1.1"
|
||||
// Dictionary doc folder to label. The last stable version should have an empty key.
|
||||
const versionMapping = {
|
||||
"master": "master",
|
||||
"": "v4.0.0 (stable)",
|
||||
"": "v4.1.1 (stable)",
|
||||
"v4.0.1": "v4.0.0/v4.0.1",
|
||||
"v3.5.1": "v3.5.0/v3.5.1",
|
||||
"v3.4.0": "v3.4.0",
|
||||
"v3.3.1": "v3.3.0/v3.3.1",
|
||||
|
||||
@@ -15,8 +15,8 @@ Benchmarks
|
||||
|
||||
Let's take a look at how 🤗 Transformer models can be benchmarked, best practices, and already available benchmarks.
|
||||
|
||||
A notebook explaining in more detail how to benchmark 🤗 Transformer models can be found `here
|
||||
<https://github.com/huggingface/transformers/blob/master/notebooks/05-benchmark.ipynb>`__.
|
||||
A notebook explaining in more detail how to benchmark 🤗 Transformer models can be found :prefix_link:`here
|
||||
<notebooks/05-benchmark.ipynb>`.
|
||||
|
||||
How to benchmark 🤗 Transformer models
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -99,6 +99,7 @@ An instantiated benchmark object can then simply be run by calling ``benchmark.r
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
==================== ENVIRONMENT INFORMATION ====================
|
||||
|
||||
- transformers_version: 2.11.0
|
||||
- framework: PyTorch
|
||||
- use_torchscript: False
|
||||
@@ -145,6 +146,7 @@ An instantiated benchmark object can then simply be run by calling ``benchmark.r
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
==================== ENVIRONMENT INFORMATION ====================
|
||||
|
||||
- transformers_version: 2.11.0
|
||||
- framework: Tensorflow
|
||||
- use_xla: False
|
||||
@@ -228,6 +230,7 @@ configurations must be inserted with the benchmark args as follows.
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
==================== ENVIRONMENT INFORMATION ====================
|
||||
|
||||
- transformers_version: 2.11.0
|
||||
- framework: PyTorch
|
||||
- use_torchscript: False
|
||||
@@ -297,6 +300,7 @@ configurations must be inserted with the benchmark args as follows.
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
==================== ENVIRONMENT INFORMATION ====================
|
||||
|
||||
- transformers_version: 2.11.0
|
||||
- framework: Tensorflow
|
||||
- use_xla: False
|
||||
@@ -353,5 +357,5 @@ The approach is detailed in the `following blogpost
|
||||
available `here
|
||||
<https://docs.google.com/spreadsheets/d/1sryqufw2D0XlUH4sq3e9Wnxu5EAQkaohzrJbd5HdQ_w/edit?usp=sharing>`__.
|
||||
|
||||
With the new `benchmark` tools, it is easier than ever to share your benchmark results with the community `here
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/benchmarking/README.md>`__.
|
||||
With the new `benchmark` tools, it is easier than ever to share your benchmark results with the community
|
||||
:prefix_link:`here <examples/benchmarking/README.md>`.
|
||||
|
||||
@@ -33,6 +33,6 @@ help people access the inner representations, mainly adapted from the great work
|
||||
* retrieving heads output values and gradients to be able to compute head importance score and prune head as explained
|
||||
in https://arxiv.org/abs/1905.10650.
|
||||
|
||||
To help you understand and use these features, we have added a specific example script: `bertology.py
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/research_projects/bertology/run_bertology.py>`_ while
|
||||
extract information and prune a model pre-trained on GLUE.
|
||||
To help you understand and use these features, we have added a specific example script: :prefix_link:`bertology.py
|
||||
<examples/research_projects/bertology/run_bertology.py>` while extract information and prune a model pre-trained on
|
||||
GLUE.
|
||||
|
||||
@@ -26,8 +26,11 @@ author = u'huggingface'
|
||||
# The short X.Y version
|
||||
version = u''
|
||||
# The full version, including alpha/beta/rc tags
|
||||
release = u'4.1.0'
|
||||
|
||||
release = u'4.2.0'
|
||||
# Prefix link to point to master, comment this during version release and uncomment below line
|
||||
extlinks = {'prefix_link': ('https://github.com/huggingface/transformers/blob/master/%s', '')}
|
||||
# Prefix link to always point to corresponding version, uncomment this during version release
|
||||
# extlinks = {'prefix_link': ('https://github.com/huggingface/transformers/blob/v'+ release + '/%s', '')}
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
@@ -40,6 +43,7 @@ release = u'4.1.0'
|
||||
# ones.
|
||||
extensions = [
|
||||
'sphinx.ext.autodoc',
|
||||
'sphinx.ext.extlinks',
|
||||
'sphinx.ext.coverage',
|
||||
'sphinx.ext.napoleon',
|
||||
'recommonmark',
|
||||
|
||||
@@ -27,9 +27,8 @@ BERT
|
||||
|
||||
You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google
|
||||
<https://github.com/google-research/bert#pre-trained-models>`_\ ) in a PyTorch save file by using the
|
||||
`convert_bert_original_tf_checkpoint_to_pytorch.py
|
||||
<https://github.com/huggingface/transformers/blob/master/src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py>`_
|
||||
script.
|
||||
:prefix_link:`convert_bert_original_tf_checkpoint_to_pytorch.py
|
||||
<src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py>` script.
|
||||
|
||||
This CLI takes as input a TensorFlow checkpoint (three files starting with ``bert_model.ckpt``\ ) and the associated
|
||||
configuration file (\ ``bert_config.json``\ ), and creates a PyTorch model for this configuration, loads the weights
|
||||
@@ -66,9 +65,8 @@ ALBERT
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Convert TensorFlow model checkpoints of ALBERT to PyTorch using the
|
||||
`convert_albert_original_tf_checkpoint_to_pytorch.py
|
||||
<https://github.com/huggingface/transformers/blob/master/src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py>`_
|
||||
script.
|
||||
:prefix_link:`convert_albert_original_tf_checkpoint_to_pytorch.py
|
||||
<src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py>` script.
|
||||
|
||||
The CLI takes as input a TensorFlow checkpoint (three files starting with ``model.ckpt-best``\ ) and the accompanying
|
||||
configuration file (\ ``albert_config.json``\ ), then creates and saves a PyTorch model. To run this conversion you
|
||||
|
||||
@@ -558,12 +558,15 @@ we can use the built in :func:`~transformers.BatchEncoding.char_to_token` method
|
||||
end_positions = []
|
||||
for i in range(len(answers)):
|
||||
start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
|
||||
end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))
|
||||
# if None, the answer passage has been truncated
|
||||
end_positions.append(encodings.char_to_token(i, answers[i]['answer_end']))
|
||||
|
||||
# if start position is None, the answer passage has been truncated
|
||||
if start_positions[-1] is None:
|
||||
start_positions[-1] = tokenizer.model_max_length
|
||||
|
||||
# if end position is None, the 'char_to_token' function points to the space before the correct token - > add + 1
|
||||
if end_positions[-1] is None:
|
||||
end_positions[-1] = tokenizer.model_max_length
|
||||
end_positions[-1] = encodings.char_to_token(i, answers[i]['answer_end'] + 1)
|
||||
encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
|
||||
|
||||
add_token_positions(train_encodings, train_answers)
|
||||
|
||||
@@ -226,7 +226,7 @@ Contrary to RNNs that have the position of each token embedded within them, tran
|
||||
each token. Therefore, the position IDs (``position_ids``) are used by the model to identify each token's position in
|
||||
the list of tokens.
|
||||
|
||||
They are an optional parameter. If no ``position_ids`` is passed to the model, the IDs are automatically created as
|
||||
They are an optional parameter. If no ``position_ids`` are passed to the model, the IDs are automatically created as
|
||||
absolute positional embeddings.
|
||||
|
||||
Absolute positional embeddings are selected in the range ``[0, config.max_position_embeddings - 1]``. Some models use
|
||||
|
||||
@@ -100,98 +100,103 @@ and conversion utilities for the following models:
|
||||
6. :doc:`Blenderbot <model_doc/blenderbot>` (from Facebook) released with the paper `Recipes for building an
|
||||
open-domain chatbot <https://arxiv.org/abs/2004.13637>`__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary
|
||||
Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
|
||||
7. :doc:`CamemBERT <model_doc/camembert>` (from Inria/Facebook/Sorbonne) released with the paper `CamemBERT: a Tasty
|
||||
7. :doc:`BlenderbotSmall <model_doc/blenderbot_small>` (from Facebook) released with the paper `Recipes for building an
|
||||
open-domain chatbot <https://arxiv.org/abs/2004.13637>`__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary
|
||||
Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston.
|
||||
8. :doc:`CamemBERT <model_doc/camembert>` (from Inria/Facebook/Sorbonne) released with the paper `CamemBERT: a Tasty
|
||||
French Language Model <https://arxiv.org/abs/1911.03894>`__ by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz
|
||||
Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot.
|
||||
8. :doc:`CTRL <model_doc/ctrl>` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language
|
||||
9. :doc:`CTRL <model_doc/ctrl>` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language
|
||||
Model for Controllable Generation <https://arxiv.org/abs/1909.05858>`__ by Nitish Shirish Keskar*, Bryan McCann*,
|
||||
Lav R. Varshney, Caiming Xiong and Richard Socher.
|
||||
9. :doc:`DeBERTa <model_doc/deberta>` (from Microsoft Research) released with the paper `DeBERTa: Decoding-enhanced
|
||||
BERT with Disentangled Attention <https://arxiv.org/abs/2006.03654>`__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao,
|
||||
Weizhu Chen.
|
||||
10. :doc:`DialoGPT <model_doc/dialogpt>` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale
|
||||
10. :doc:`DeBERTa <model_doc/deberta>` (from Microsoft Research) released with the paper `DeBERTa: Decoding-enhanced
|
||||
BERT with Disentangled Attention <https://arxiv.org/abs/2006.03654>`__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao,
|
||||
Weizhu Chen.
|
||||
11. :doc:`DialoGPT <model_doc/dialogpt>` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale
|
||||
Generative Pre-training for Conversational Response Generation <https://arxiv.org/abs/1911.00536>`__ by Yizhe
|
||||
Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan.
|
||||
11. :doc:`DistilBERT <model_doc/distilbert>` (from HuggingFace), released together with the paper `DistilBERT, a
|
||||
12. :doc:`DistilBERT <model_doc/distilbert>` (from HuggingFace), released together with the paper `DistilBERT, a
|
||||
distilled version of BERT: smaller, faster, cheaper and lighter <https://arxiv.org/abs/1910.01108>`__ by Victor
|
||||
Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into `DistilGPT2
|
||||
<https://github.com/huggingface/transformers/tree/master/examples/distillation>`__, RoBERTa into `DistilRoBERTa
|
||||
<https://github.com/huggingface/transformers/tree/master/examples/distillation>`__, Multilingual BERT into
|
||||
`DistilmBERT <https://github.com/huggingface/transformers/tree/master/examples/distillation>`__ and a German
|
||||
version of DistilBERT.
|
||||
12. :doc:`DPR <model_doc/dpr>` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain
|
||||
13. :doc:`DPR <model_doc/dpr>` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain
|
||||
Question Answering <https://arxiv.org/abs/2004.04906>`__ by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick
|
||||
Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
|
||||
13. :doc:`ELECTRA <model_doc/electra>` (from Google Research/Stanford University) released with the paper `ELECTRA:
|
||||
14. :doc:`ELECTRA <model_doc/electra>` (from Google Research/Stanford University) released with the paper `ELECTRA:
|
||||
Pre-training text encoders as discriminators rather than generators <https://arxiv.org/abs/2003.10555>`__ by Kevin
|
||||
Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning.
|
||||
14. :doc:`FlauBERT <model_doc/flaubert>` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model
|
||||
15. :doc:`FlauBERT <model_doc/flaubert>` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model
|
||||
Pre-training for French <https://arxiv.org/abs/1912.05372>`__ by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne,
|
||||
Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab.
|
||||
15. :doc:`Funnel Transformer <model_doc/funnel>` (from CMU/Google Brain) released with the paper `Funnel-Transformer:
|
||||
16. :doc:`Funnel Transformer <model_doc/funnel>` (from CMU/Google Brain) released with the paper `Funnel-Transformer:
|
||||
Filtering out Sequential Redundancy for Efficient Language Processing <https://arxiv.org/abs/2006.03236>`__ by
|
||||
Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
|
||||
16. :doc:`GPT <model_doc/gpt>` (from OpenAI) released with the paper `Improving Language Understanding by Generative
|
||||
17. :doc:`GPT <model_doc/gpt>` (from OpenAI) released with the paper `Improving Language Understanding by Generative
|
||||
Pre-Training <https://blog.openai.com/language-unsupervised/>`__ by Alec Radford, Karthik Narasimhan, Tim Salimans
|
||||
and Ilya Sutskever.
|
||||
17. :doc:`GPT-2 <model_doc/gpt2>` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask
|
||||
18. :doc:`GPT-2 <model_doc/gpt2>` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask
|
||||
Learners <https://blog.openai.com/better-language-models/>`__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David
|
||||
Luan, Dario Amodei** and Ilya Sutskever**.
|
||||
18. :doc:`LayoutLM <model_doc/layoutlm>` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training
|
||||
19. :doc:`LayoutLM <model_doc/layoutlm>` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training
|
||||
of Text and Layout for Document Image Understanding <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li,
|
||||
Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
|
||||
19. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document
|
||||
20. :doc:`LED <model_doc/led>` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer
|
||||
<https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
21. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document
|
||||
Transformer <https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
20. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
|
||||
22. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
|
||||
Encoder Representations from Transformers for Open-Domain Question Answering <https://arxiv.org/abs/1908.07490>`__
|
||||
by Hao Tan and Mohit Bansal.
|
||||
21. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
|
||||
23. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
|
||||
Jörg Tiedemann. The `Marian Framework <https://marian-nmt.github.io/>`__ is being developed by the Microsoft
|
||||
Translator Team.
|
||||
22. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
|
||||
24. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
|
||||
Neural Machine Translation <https://arxiv.org/abs/2001.08210>`__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li,
|
||||
Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
|
||||
23. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
|
||||
25. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
|
||||
Pre-training for Language Understanding <https://arxiv.org/abs/2004.09297>`__ by Kaitao Song, Xu Tan, Tao Qin,
|
||||
Jianfeng Lu, Tie-Yan Liu.
|
||||
24. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
|
||||
26. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
|
||||
text-to-text transformer <https://arxiv.org/abs/2010.11934>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir
|
||||
Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
|
||||
25. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
|
||||
27. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
|
||||
Gap-sentences for Abstractive Summarization <https://arxiv.org/abs/1912.08777>`__> by Jingqing Zhang, Yao Zhao,
|
||||
Mohammad Saleh and Peter J. Liu.
|
||||
26. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
|
||||
28. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
|
||||
Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi,
|
||||
Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
27. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
|
||||
29. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
|
||||
Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
|
||||
28. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
|
||||
30. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
|
||||
Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
|
||||
Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. ultilingual BERT into `DistilmBERT
|
||||
<https://github.com/huggingface/transformers/tree/master/examples/distillation>`__ and a German version of
|
||||
DistilBERT.
|
||||
29. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
|
||||
31. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
|
||||
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
|
||||
Krishna, and Kurt W. Keutzer.
|
||||
30. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
32. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
|
||||
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
|
||||
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
|
||||
31. `TAPAS <https://huggingface.co/transformers/master/model_doc/tapas.html>`__ released with the paper `TAPAS: Weakly
|
||||
Supervised Table Parsing via Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof
|
||||
Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
32. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
33. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
|
||||
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
|
||||
Francesco Piccinno and Julian Martin Eisenschlos.
|
||||
34. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
|
||||
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
|
||||
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
|
||||
33. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
35. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
|
||||
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
|
||||
34. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
36. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
|
||||
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
|
||||
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
|
||||
35. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
37. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
|
||||
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
|
||||
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
|
||||
Zettlemoyer and Veselin Stoyanov.
|
||||
36. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
38. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `XLNet: Generalized Autoregressive
|
||||
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
|
||||
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
|
||||
|
||||
@@ -220,6 +225,8 @@ TensorFlow and/or Flax.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Blenderbot | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| BlenderbotSmall | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
@@ -240,6 +247,8 @@ TensorFlow and/or Flax.
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| LED | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
|
||||
| LayoutLM | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
@@ -356,8 +365,10 @@ TensorFlow and/or Flax.
|
||||
model_doc/bart
|
||||
model_doc/barthez
|
||||
model_doc/bert
|
||||
model_doc/bertweet
|
||||
model_doc/bertgeneration
|
||||
model_doc/blenderbot
|
||||
model_doc/blenderbot_small
|
||||
model_doc/camembert
|
||||
model_doc/ctrl
|
||||
model_doc/deberta
|
||||
@@ -369,7 +380,9 @@ TensorFlow and/or Flax.
|
||||
model_doc/flaubert
|
||||
model_doc/fsmt
|
||||
model_doc/funnel
|
||||
model_doc/herbert
|
||||
model_doc/layoutlm
|
||||
model_doc/led
|
||||
model_doc/longformer
|
||||
model_doc/lxmert
|
||||
model_doc/marian
|
||||
@@ -380,6 +393,7 @@ TensorFlow and/or Flax.
|
||||
model_doc/gpt
|
||||
model_doc/gpt2
|
||||
model_doc/pegasus
|
||||
model_doc/phobert
|
||||
model_doc/prophetnet
|
||||
model_doc/rag
|
||||
model_doc/reformer
|
||||
|
||||
@@ -13,13 +13,102 @@
|
||||
Utilities for Generation
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
This page lists all the utility functions used by :meth:`~transformers.PretrainedModel.generate`,
|
||||
:meth:`~transformers.PretrainedModel.greedy_search`, :meth:`~transformers.PretrainedModel.sample`,
|
||||
:meth:`~transformers.PretrainedModel.beam_search`, :meth:`~transformers.PretrainedModel.beam_sample`, and
|
||||
:meth:`~transformers.PretrainedModel.group_beam_search`.
|
||||
This page lists all the utility functions used by :meth:`~transformers.PreTrainedModel.generate`,
|
||||
:meth:`~transformers.PreTrainedModel.greedy_search`, :meth:`~transformers.PreTrainedModel.sample`,
|
||||
:meth:`~transformers.PreTrainedModel.beam_search`, :meth:`~transformers.PreTrainedModel.beam_sample`, and
|
||||
:meth:`~transformers.PreTrainedModel.group_beam_search`.
|
||||
|
||||
Most of those are only useful if you are studying the code of the generate methods in the library.
|
||||
|
||||
Generate Outputs
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The output of :meth:`~transformers.PreTrainedModel.generate` is an instance of a subclass of
|
||||
:class:`~transformers.file_utils.ModelOutput`. This output is a data structure containing all the information returned
|
||||
by :meth:`~transformers.PreTrainedModel.generate`, but that can also be used as tuple or dictionary.
|
||||
|
||||
Here's an example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
model = GPT2LMHeadModel.from_pretrained('gpt2')
|
||||
|
||||
inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
|
||||
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
|
||||
|
||||
The ``generation_output`` object is a :class:`~transformers.generation_utils.GreedySearchDecoderOnlyOutput`, as we can
|
||||
see in the documentation of that class below, it means it has the following attributes:
|
||||
|
||||
- ``sequences``: the generated sequences of tokens
|
||||
- ``scores`` (optional): the prediction scores of the language modelling head, for each generation step
|
||||
- ``hidden_states`` (optional): the hidden states of the model, for each generation step
|
||||
- ``attentions`` (optional): the attention weights of the model, for each generation step
|
||||
|
||||
Here we have the ``scores`` since we passed along ``output_scores=True``, but we don't have ``hidden_states`` and
|
||||
``attentions`` because we didn't pass ``output_hidden_states=True`` or ``output_attentions=True``.
|
||||
|
||||
You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you
|
||||
will get ``None``. Here for instance ``generation_output.scores`` are all the generated prediction scores of the
|
||||
language modeling head, and ``generation_output.attentions`` is ``None``.
|
||||
|
||||
When using our ``generation_output`` object as a tuple, it only keeps the attributes that don't have ``None`` values.
|
||||
Here, for instance, it has two elements, ``loss`` then ``logits``, so
|
||||
|
||||
.. code-block::
|
||||
|
||||
generation_output[:2]
|
||||
|
||||
will return the tuple ``(generation_output.sequences, generation_output.scores)`` for instance.
|
||||
|
||||
When using our ``generation_output`` object as a dictionary, it only keeps the attributes that don't have ``None``
|
||||
values. Here, for instance, it has two keys that are ``sequences`` and ``scores``.
|
||||
|
||||
We document here all output types.
|
||||
|
||||
|
||||
GreedySearchOutput
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: transformers.generation_utils.GreedySearchDecoderOnlyOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput
|
||||
:members:
|
||||
|
||||
|
||||
SampleOutput
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: transformers.generation_utils.SampleDecoderOnlyOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput
|
||||
:members:
|
||||
|
||||
|
||||
BeamSearchOutput
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: transformers.generation_utils.BeamSearchDecoderOnlyOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.generation_utils.BeamSearchEncoderDecoderOutput
|
||||
:members:
|
||||
|
||||
|
||||
BeamSampleOutput
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: transformers.generation_utils.BeamSampleDecoderOnlyOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.generation_utils.BeamSampleEncoderDecoderOutput
|
||||
:members:
|
||||
|
||||
|
||||
LogitsProcessor
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -43,6 +43,10 @@ Schedules
|
||||
Learning Rate Schedules (Pytorch)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: transformers.SchedulerType
|
||||
|
||||
.. autofunction:: transformers.get_scheduler
|
||||
|
||||
.. autofunction:: transformers.get_constant_schedule
|
||||
|
||||
|
||||
|
||||
@@ -126,13 +126,6 @@ CausalLMOutputWithCrossAttentions
|
||||
:members:
|
||||
|
||||
|
||||
CausalLMOutputWithPastAndCrossAttentions
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithPastAndCrossAttentions
|
||||
:members:
|
||||
|
||||
|
||||
CausalLMOutputWithPast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -168,5 +168,5 @@ Using `tensorflow_datasets` is as easy as using a data file:
|
||||
)
|
||||
|
||||
|
||||
Another example using these processors is given in the `run_squad.py
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/question-answering/run_squad.py>`__ script.
|
||||
Another example using these processors is given in the :prefix_link:`run_squad.py
|
||||
<examples/question-answering/run_squad.py>` script.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
..
|
||||
..
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
@@ -63,6 +63,13 @@ Trainer
|
||||
:members:
|
||||
|
||||
|
||||
Seq2SeqTrainer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.Seq2SeqTrainer
|
||||
:members: evaluate, predict
|
||||
|
||||
|
||||
TFTrainer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -77,8 +84,450 @@ TrainingArguments
|
||||
:members:
|
||||
|
||||
|
||||
Seq2SeqTrainingArguments
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.Seq2SeqTrainingArguments
|
||||
:members:
|
||||
|
||||
|
||||
TFTrainingArguments
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFTrainingArguments
|
||||
:members:
|
||||
|
||||
|
||||
Trainer Integrations
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
|
||||
The :class:`~transformers.Trainer` has been extended to support libraries that may dramatically improve your training
|
||||
time and fit much bigger models.
|
||||
|
||||
Currently it supports third party solutions, `DeepSpeed <https://github.com/microsoft/DeepSpeed>`__ and `FairScale
|
||||
<https://github.com/facebookresearch/fairscale/>`__, which implement parts of the paper `ZeRO: Memory Optimizations
|
||||
Toward Training Trillion Parameter Models, by Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He
|
||||
<https://arxiv.org/abs/1910.02054>`__.
|
||||
|
||||
This provided support is new and experimental as of this writing.
|
||||
|
||||
You will need at least 2 GPUs to benefit from these features.
|
||||
|
||||
FairScale
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
By integrating `FairScale <https://github.com/facebookresearch/fairscale/>`__ the :class:`~transformers.Trainer`
|
||||
provides support for the following features from `the ZeRO paper <https://arxiv.org/abs/1910.02054>`__:
|
||||
|
||||
1. Optimizer State Sharding
|
||||
2. Gradient Sharding
|
||||
|
||||
To deploy this feature:
|
||||
|
||||
1. Install the library via pypi:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install fairscale
|
||||
|
||||
or find more details on `the FairScale's github page
|
||||
<https://github.com/facebookresearch/fairscale/#installation>`__.
|
||||
|
||||
2. Add ``--sharded_ddp`` to the command line arguments, and make sure you have added the distributed launcher ``-m
|
||||
torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
|
||||
|
||||
For example here is how you could use it for ``finetune_trainer.py`` with 2 GPUs:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd examples/seq2seq
|
||||
python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py \
|
||||
--model_name_or_path sshleifer/distill-mbart-en-ro-12-4 --data_dir wmt_en_ro \
|
||||
--output_dir output_dir --overwrite_output_dir \
|
||||
--do_train --n_train 500 --num_train_epochs 1 \
|
||||
--per_device_train_batch_size 1 --freeze_embeds \
|
||||
--src_lang en_XX --tgt_lang ro_RO --task translation \
|
||||
--fp16 --sharded_ddp
|
||||
|
||||
Notes:
|
||||
|
||||
- This feature requires distributed training (so multiple GPUs).
|
||||
- It is not implemented for TPUs.
|
||||
- It works with ``--fp16`` too, to make things even faster.
|
||||
- One of the main benefits of enabling ``--sharded_ddp`` is that it uses a lot less GPU memory, so you should be able
|
||||
to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
|
||||
significantly shorter training time.
|
||||
|
||||
|
||||
DeepSpeed
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
`DeepSpeed <https://github.com/microsoft/DeepSpeed>`__ implements everything described in the `ZeRO paper
|
||||
<https://arxiv.org/abs/1910.02054>`__, except ZeRO's stage 3. "Parameter Partitioning (Pos+g+p)". Currently it provides
|
||||
full support for:
|
||||
|
||||
1. Optimizer State Partitioning (ZeRO stage 1)
|
||||
2. Add Gradient Partitioning (ZeRO stage 2)
|
||||
|
||||
To deploy this feature:
|
||||
|
||||
1. Install the library via pypi:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install deepspeed
|
||||
|
||||
or find more details on `the DeepSpeed's github page <https://github.com/microsoft/deepspeed#installation>`__.
|
||||
|
||||
2. Adjust the :class:`~transformers.Trainer` command line arguments as following:
|
||||
|
||||
1. replace ``python -m torch.distributed.launch`` with ``deepspeed``.
|
||||
2. add a new argument ``--deepspeed ds_config.json``, where ``ds_config.json`` is the DeepSpeed configuration file
|
||||
as documented `here <https://www.deepspeed.ai/docs/config-json/>`__. The file naming is up to you.
|
||||
|
||||
Therefore, if your original command line looked as following:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python -m torch.distributed.launch --nproc_per_node=2 your_program.py <normal cl args>
|
||||
|
||||
Now it should be:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
deepspeed --num_gpus=2 your_program.py <normal cl args> --deepspeed ds_config.json
|
||||
|
||||
Unlike, ``torch.distributed.launch`` where you have to specify how many GPUs to use with ``--nproc_per_node``, with
|
||||
the ``deepspeed`` launcher you don't have to use the corresponding ``--num_gpus`` if you want all of your GPUs used.
|
||||
The full details on how to configure various nodes and GPUs can be found `here
|
||||
<https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node>`__.
|
||||
|
||||
Here is an example of running ``finetune_trainer.py`` under DeepSpeed deploying all available GPUs:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
cd examples/seq2seq
|
||||
deepspeed ./finetune_trainer.py --deepspeed ds_config.json \
|
||||
--model_name_or_path sshleifer/distill-mbart-en-ro-12-4 --data_dir wmt_en_ro \
|
||||
--output_dir output_dir --overwrite_output_dir \
|
||||
--do_train --n_train 500 --num_train_epochs 1 \
|
||||
--per_device_train_batch_size 1 --freeze_embeds \
|
||||
--src_lang en_XX --tgt_lang ro_RO --task translation
|
||||
|
||||
Note that in the DeepSpeed documentation you are likely to see ``--deepspeed --deepspeed_config ds_config.json`` -
|
||||
i.e. two DeepSpeed-related arguments, but for the sake of simplicity, and since there are already so many arguments
|
||||
to deal with, we combined the two into a single argument.
|
||||
|
||||
Before you can deploy DeepSpeed, let's discuss its configuration.
|
||||
|
||||
**Configuration:**
|
||||
|
||||
For the complete guide to the DeepSpeed configuration options that can be used in its configuration file please refer
|
||||
to the `following documentation <https://www.deepspeed.ai/docs/config-json/>`__.
|
||||
|
||||
While you always have to supply the DeepSpeed configuration file, you can configure the DeepSpeed integration in
|
||||
several ways:
|
||||
|
||||
1. Supply most of the configuration inside the file, and just use a few required command line arguments. This is the
|
||||
recommended way as it puts most of the configuration params in one place.
|
||||
2. Supply just the ZeRO configuration params inside the file, and configure the rest using the normal
|
||||
:class:`~transformers.Trainer` command line arguments.
|
||||
3. Any variation of the first two ways.
|
||||
|
||||
To get an idea of what DeepSpeed configuration file looks like, here is one that activates ZeRO stage 2 features,
|
||||
enables FP16, uses AdamW optimizer and WarmupLR scheduler:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients": true,
|
||||
"cpu_offload": true
|
||||
},
|
||||
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 3e-5,
|
||||
"betas": [ 0.8, 0.999 ],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
"zero_allow_untested_optimizer": true,
|
||||
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": 3e-5,
|
||||
"warmup_num_steps": 500
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
If you already have a command line that you have been using with :class:`transformers.Trainer` args, you can continue
|
||||
using those and the :class:`~transformers.Trainer` will automatically convert them into the corresponding DeepSpeed
|
||||
configuration at run time. For example, you could use the following configuration file:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients": true,
|
||||
"cpu_offload": true
|
||||
}
|
||||
}
|
||||
|
||||
and the following command line arguments:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
--learning_rate 3e-5 --warmup_steps 500 --adam_beta1 0.8 --adam_beta2 0.999 --adam_epsilon 1e-8 \
|
||||
--weight_decay 3e-7 --lr_scheduler_type constant_with_warmup --fp16 --fp16_backend amp
|
||||
|
||||
to achieve the same configuration as provided by the longer json file in the first example.
|
||||
|
||||
When you execute the program, DeepSpeed will log the configuration it received from the :class:`~transformers.Trainer`
|
||||
to the console, so you can see exactly what the final configuration was passed to it.
|
||||
|
||||
**Shared Configuration:**
|
||||
|
||||
Some configuration information is required by both the :class:`~transformers.Trainer` and DeepSpeed to function
|
||||
correctly, therefore, to prevent conflicting definitions, which could lead to hard to detect errors, we chose to
|
||||
configure those via the :class:`~transformers.Trainer` command line arguments.
|
||||
|
||||
Therefore, the following DeepSpeed configuration params shouldn't be used with the :class:`~transformers.Trainer`:
|
||||
|
||||
* ``train_batch_size``
|
||||
* ``train_micro_batch_size_per_gpu``
|
||||
* ``gradient_accumulation_steps``
|
||||
|
||||
as these will be automatically derived from the run time environment and the following 2 command line arguments:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
--per_device_train_batch_size 8 --gradient_accumulation_steps 2
|
||||
|
||||
which are always required to be supplied.
|
||||
|
||||
Of course, you will need to adjust the values in this example to your situation.
|
||||
|
||||
|
||||
|
||||
**ZeRO:**
|
||||
|
||||
The ``zero_optimization`` section of the configuration file is the most important part (`docs
|
||||
<https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training>`__), since that is where you define
|
||||
which ZeRO stages you want to enable and how to configure them.
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients": true,
|
||||
"cpu_offload": true
|
||||
}
|
||||
}
|
||||
|
||||
Notes:
|
||||
|
||||
- enabling ``cpu_offload`` should reduce GPU RAM usage (it requires ``"stage": 2``)
|
||||
- ``"overlap_comm": true`` trades off increased GPU RAM usage to lower all-reduce latency. ``overlap_comm`` uses 4.5x
|
||||
the ``allgather_bucket_size`` and ``reduce_bucket_size`` values. So if they are set to 5e8, this requires a 9GB
|
||||
footprint (``5e8 x 2Bytes x 2 x 4.5``). Therefore, if you have a GPU with 8GB or less RAM, to avoid getting
|
||||
OOM-errors you will need to reduce those parameters to about ``2e8``, which would require 3.6GB.
|
||||
|
||||
This section has to be configured exclusively via DeepSpeed configuration - the :class:`~transformers.Trainer` provides
|
||||
no equivalent command line arguments.
|
||||
|
||||
|
||||
|
||||
**Optimizer:**
|
||||
|
||||
|
||||
DeepSpeed's main optimizers are Adam, OneBitAdam, and Lamb. These have been thoroughly tested with ZeRO and are thus
|
||||
recommended to be used. It, however, can import other optimizers from ``torch``. The full documentation is `here
|
||||
<https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`__.
|
||||
|
||||
If you don't configure the ``optimizer`` entry in the configuration file, the :class:`~transformers.Trainer` will
|
||||
automatically set it to ``AdamW`` and will use the supplied values or the defaults for the following command line
|
||||
arguments: ``--learning_rate``, ``--adam_beta1``, ``--adam_beta2``, ``--adam_epsilon`` and ``--weight_decay``.
|
||||
|
||||
Here is an example of the pre-configured ``optimizer`` entry for AdamW:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"zero_allow_untested_optimizer": true,
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 0.001,
|
||||
"betas": [0.8, 0.999],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Since AdamW isn't on the list of tested with DeepSpeed/ZeRO optimizers, we have to add
|
||||
``zero_allow_untested_optimizer`` flag.
|
||||
|
||||
If you want to use one of the officially supported optimizers, configure them explicitly in the configuration file, and
|
||||
make sure to adjust the values. e.g. if use Adam you will want ``weight_decay`` around ``0.01``.
|
||||
|
||||
|
||||
**Scheduler:**
|
||||
|
||||
DeepSpeed supports LRRangeTest, OneCycle, WarmupLR and WarmupDecayLR LR schedulers. The full documentation is `here
|
||||
<https://www.deepspeed.ai/docs/config-json/#scheduler-parameters>`__.
|
||||
|
||||
If you don't configure the ``scheduler`` entry in the configuration file, the :class:`~transformers.Trainer` will use
|
||||
the value of ``--lr_scheduler_type`` to configure it. Currently the :class:`~transformers.Trainer` supports only 2 LR
|
||||
schedulers that are also supported by DeepSpeed:
|
||||
|
||||
* ``WarmupLR`` via ``--lr_scheduler_type constant_with_warmup``
|
||||
* ``WarmupDecayLR`` via ``--lr_scheduler_type linear``. This is also the default value for ``--lr_scheduler_type``,
|
||||
therefore, if you don't configure the scheduler this is scheduler that will get configured by default.
|
||||
|
||||
In either case, the values of ``--learning_rate`` and ``--warmup_steps`` will be used for the configuration.
|
||||
|
||||
In other words, if you don't use the configuration file to set the ``scheduler`` entry, provide either:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
--lr_scheduler_type constant_with_warmup --learning_rate 3e-5 --warmup_steps 500
|
||||
|
||||
or
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
--lr_scheduler_type linear --learning_rate 3e-5 --warmup_steps 500
|
||||
|
||||
with the desired values. If you don't pass these arguments, reasonable default values will be used instead.
|
||||
|
||||
In the case of WarmupDecayLR ``total_num_steps`` gets set either via the ``--max_steps`` command line argument, or if
|
||||
it is not provided, derived automatically at run time based on the environment and the size of the dataset and other
|
||||
command line arguments.
|
||||
|
||||
Here is an example of the pre-configured ``scheduler`` entry for WarmupLR (``constant_with_warmup`` in the
|
||||
:class:`~transformers.Trainer` API):
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": 0.001,
|
||||
"warmup_num_steps": 1000
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
**Automatic Mixed Precision:**
|
||||
|
||||
You can work with FP16 in one of the following ways:
|
||||
|
||||
1. Pytorch native amp, as documented `here <https://www.deepspeed.ai/docs/config-json/#fp16-training-options>`__.
|
||||
2. NVIDIA's apex, as documented `here
|
||||
<https://www.deepspeed.ai/docs/config-json/#automatic-mixed-precision-amp-training-options>`__.
|
||||
|
||||
If you want to use an equivalent of the pytorch native amp, you can either configure the ``fp16`` entry in the
|
||||
configuration file, or use the following command line arguments: ``--fp16 --fp16_backend amp``.
|
||||
|
||||
Here is an example of the ``fp16`` configuration:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
}
|
||||
|
||||
If you want to use NVIDIA's apex instead, you can can either configure the ``amp`` entry in the configuration file, or
|
||||
use the following command line arguments: ``--fp16 --fp16_backend apex --fp16_opt_level 01``.
|
||||
|
||||
Here is an example of the ``amp`` configuration:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"amp": {
|
||||
"enabled": true,
|
||||
"opt_level": "O1"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
**Gradient Clipping:**
|
||||
|
||||
If you don't configure the ``gradient_clipping`` entry in the configuration file, the :class:`~transformers.Trainer`
|
||||
will use the value of the ``--max_grad_norm`` command line argument to set it.
|
||||
|
||||
Here is an example of the ``gradient_clipping`` configuration:
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"gradient_clipping": 1.0,
|
||||
}
|
||||
|
||||
|
||||
|
||||
**Notes:**
|
||||
|
||||
* DeepSpeed works with the PyTorch :class:`~transformers.Trainer` but not TF :class:`~transformers.TFTrainer`.
|
||||
* While DeepSpeed has a pip installable PyPI package, it is highly recommended that it gets installed from `source
|
||||
<https://github.com/microsoft/deepspeed#installation>`__ to best match your hardware and also if you need to enable
|
||||
certain features, like 1-bit Adam, which aren't available in the pypi distribution.
|
||||
* You don't have to use the :class:`~transformers.Trainer` to use DeepSpeed with HuggingFace ``transformers`` - you can
|
||||
use any model with your own trainer, and you will have to adapt the latter according to `the DeepSpeed integration
|
||||
instructions <https://www.deepspeed.ai/getting-started/#writing-deepspeed-models>`__.
|
||||
|
||||
**Main DeepSpeed Resources:**
|
||||
|
||||
- `github <https://github.com/microsoft/deepspeed>`__
|
||||
- `Usage docs <https://www.deepspeed.ai/getting-started/>`__
|
||||
- `API docs <https://deepspeed.readthedocs.io/en/latest/index.html>`__
|
||||
|
||||
Finally, please, remember that, HuggingFace :class:`~transformers.Trainer` only integrates DeepSpeed, therefore if you
|
||||
have any problems or questions with regards to DeepSpeed usage, please, file an issue with `DeepSpeed github
|
||||
<https://github.com/microsoft/DeepSpeed/issues>`__.
|
||||
|
||||
@@ -42,7 +42,7 @@ Examples
|
||||
_______________________________________________________________________________________________________________________
|
||||
|
||||
- Examples and scripts for fine-tuning BART and other models for sequence to sequence tasks can be found in
|
||||
`examples/seq2seq/ <https://github.com/huggingface/transformers/blob/master/examples/seq2seq/README.md>`__.
|
||||
:prefix_link:`examples/seq2seq/ <examples/seq2seq/README.md>`.
|
||||
- An example of how to train :class:`~transformers.BartForConditionalGeneration` with a Hugging Face :obj:`datasets`
|
||||
object can be found in this `forum discussion
|
||||
<https://discuss.huggingface.co/t/train-bart-for-conditional-generation-e-g-summarization/1904>`__.
|
||||
@@ -55,9 +55,8 @@ Implementation Notes
|
||||
|
||||
- Bart doesn't use :obj:`token_type_ids` for sequence classification. Use :class:`~transformers.BartTokenizer` or
|
||||
:meth:`~transformers.BartTokenizer.encode` to get the proper splitting.
|
||||
- The forward pass of :class:`~transformers.BartModel` will create decoder inputs (using the helper function
|
||||
:func:`transformers.models.bart.modeling_bart._prepare_bart_decoder_inputs`) if they are not passed. This is
|
||||
different than some other modeling APIs.
|
||||
- The forward pass of :class:`~transformers.BartModel` will create the ``decoder_input_ids`` if they are not passed.
|
||||
This is different than some other modeling APIs. A typical use case of this feature is mask filling.
|
||||
- Model predictions are intended to be identical to the original implementation when
|
||||
:obj:`force_bos_token_to_be_generated=True`. This only works, however, if the string you pass to
|
||||
:func:`fairseq.encode` starts with a space.
|
||||
@@ -65,7 +64,6 @@ Implementation Notes
|
||||
summarization, see the example in that docstrings.
|
||||
- Models that load the `facebook/bart-large-cnn` weights will not have a :obj:`mask_token_id`, or be able to perform
|
||||
mask-filling tasks.
|
||||
- For training/forward passes that don't involve beam search, pass :obj:`use_cache=False`.
|
||||
|
||||
Mask Filling
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -41,8 +41,8 @@ The Authors' code can be found `here <https://github.com/moussaKam/BARThez>`__.
|
||||
Examples
|
||||
_______________________________________________________________________________________________________________________
|
||||
|
||||
- BARThez can be fine-tuned on sequence-to-sequence tasks in a similar way as BART, check: `examples/seq2seq/
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/seq2seq/README.md>`__.
|
||||
- BARThez can be fine-tuned on sequence-to-sequence tasks in a similar way as BART, check:
|
||||
:prefix_link:`examples/seq2seq/ <examples/seq2seq/README.md>`.
|
||||
|
||||
|
||||
BarthezTokenizer
|
||||
|
||||
64
docs/source/model_doc/bertweet.rst
Normal file
64
docs/source/model_doc/bertweet.rst
Normal file
@@ -0,0 +1,64 @@
|
||||
..
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
Bertweet
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The BERTweet model was proposed in `BERTweet: A pre-trained language model for English Tweets
|
||||
<https://www.aclweb.org/anthology/2020.emnlp-demos.2.pdf>`__ by Dat Quoc Nguyen, Thanh Vu, Anh Tuan Nguyen.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We present BERTweet, the first public large-scale pre-trained language model for English Tweets. Our BERTweet, having
|
||||
the same architecture as BERT-base (Devlin et al., 2019), is trained using the RoBERTa pre-training procedure (Liu et
|
||||
al., 2019). Experiments show that BERTweet outperforms strong baselines RoBERTa-base and XLM-R-base (Conneau et al.,
|
||||
2020), producing better performance results than the previous state-of-the-art models on three Tweet NLP tasks:
|
||||
Part-of-speech tagging, Named-entity recognition and text classification.*
|
||||
|
||||
Example of use:
|
||||
|
||||
.. code-block::
|
||||
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
bertweet = AutoModel.from_pretrained("vinai/bertweet-base")
|
||||
|
||||
# For transformers v4.x+:
|
||||
tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base", use_fast=False)
|
||||
|
||||
# For transformers v3.x:
|
||||
# tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base")
|
||||
|
||||
# INPUT TWEET IS ALREADY NORMALIZED!
|
||||
line = "SC has first two presumptive cases of coronavirus , DHEC confirms HTTPURL via @USER :cry:"
|
||||
|
||||
input_ids = torch.tensor([tokenizer.encode(line)])
|
||||
|
||||
with torch.no_grad():
|
||||
features = bertweet(input_ids) # Models outputs are now tuples
|
||||
|
||||
## With TensorFlow 2.0+:
|
||||
# from transformers import TFAutoModel
|
||||
# bertweet = TFAutoModel.from_pretrained("vinai/bertweet-base")
|
||||
|
||||
|
||||
The original code can be found `here <https://github.com/VinAIResearch/BERTweet>`__.
|
||||
|
||||
BertweetTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BertweetTokenizer
|
||||
:members:
|
||||
@@ -43,13 +43,10 @@ Implementation Notes
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- Blenderbot uses a standard `seq2seq model transformer <https://arxiv.org/pdf/1706.03762.pdf>`__ based architecture.
|
||||
- It inherits completely from :class:`~transformers.BartForConditionalGeneration`
|
||||
- Even though blenderbot is one model, it uses two tokenizers :class:`~transformers.BlenderbotSmallTokenizer` for 90M
|
||||
checkpoint and :class:`~transformers.BlenderbotTokenizer` for all other checkpoints.
|
||||
- :class:`~transformers.BlenderbotSmallTokenizer` will always return :class:`~transformers.BlenderbotSmallTokenizer`,
|
||||
regardless of checkpoint. To use the 3B parameter checkpoint, you must call
|
||||
:class:`~transformers.BlenderbotTokenizer` directly.
|
||||
- Available checkpoints can be found in the `model hub <https://huggingface.co/models?search=blenderbot>`__.
|
||||
- This is the `default` Blenderbot model class. However, some smaller checkpoints, such as
|
||||
``facebook/blenderbot_small_90M``, have a different architecture and consequently should be used with
|
||||
`BlenderbotSmall <https://huggingface.co/transformers/master/model_doc/blenderbot_small.html>`__.
|
||||
|
||||
|
||||
Usage
|
||||
@@ -59,26 +56,15 @@ Here is an example of model usage:
|
||||
|
||||
.. code-block::
|
||||
|
||||
>>> from transformers import BlenderbotSmallTokenizer, BlenderbotForConditionalGeneration
|
||||
>>> mname = 'facebook/blenderbot-90M'
|
||||
>>> from transformers import BlenderbotTokenizer, BlenderbotForConditionalGeneration
|
||||
>>> mname = 'facebook/blenderbot-400M-distill'
|
||||
>>> model = BlenderbotForConditionalGeneration.from_pretrained(mname)
|
||||
>>> tokenizer = BlenderbotSmallTokenizer.from_pretrained(mname)
|
||||
>>> tokenizer = BlenderbotTokenizer.from_pretrained(mname)
|
||||
>>> UTTERANCE = "My friends are cool but they eat too many carbs."
|
||||
>>> inputs = tokenizer([UTTERANCE], return_tensors='pt')
|
||||
>>> reply_ids = model.generate(**inputs)
|
||||
>>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in reply_ids])
|
||||
|
||||
|
||||
Here is how you can check out config values:
|
||||
|
||||
.. code-block::
|
||||
|
||||
|
||||
>>> from transformers import BlenderbotConfig
|
||||
>>> config_90 = BlenderbotConfig.from_pretrained("facebook/blenderbot-90M")
|
||||
>>> config_90.to_diff_dict() # show interesting Values.
|
||||
>>> configuration_3B = BlenderbotConfig("facebook/blenderbot-3B")
|
||||
>>> configuration_3B.to_diff_dict()
|
||||
>>> print(tokenizer.batch_decode(reply_ids))
|
||||
["<s> That's unfortunate. Are they trying to lose weight or are they just trying to be healthier?</s>"]
|
||||
|
||||
|
||||
BlenderbotConfig
|
||||
@@ -93,11 +79,14 @@ BlenderbotTokenizer
|
||||
.. autoclass:: transformers.BlenderbotTokenizer
|
||||
:members: build_inputs_with_special_tokens
|
||||
|
||||
BlenderbotSmallTokenizer
|
||||
|
||||
BlenderbotModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BlenderbotSmallTokenizer
|
||||
:members:
|
||||
See :obj:`transformers.BartModel` for arguments to `forward` and `generate`
|
||||
|
||||
.. autoclass:: transformers.BlenderbotModel
|
||||
:members: forward
|
||||
|
||||
|
||||
BlenderbotForConditionalGeneration
|
||||
@@ -106,13 +95,18 @@ BlenderbotForConditionalGeneration
|
||||
See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward` and `generate`
|
||||
|
||||
.. autoclass:: transformers.BlenderbotForConditionalGeneration
|
||||
:members:
|
||||
:members: forward
|
||||
|
||||
|
||||
TFBlenderbotModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFBlenderbotModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFBlenderbotForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
See :obj:`transformers.TFBartForConditionalGeneration` for arguments to `forward` and `generate`
|
||||
|
||||
.. autoclass:: transformers.TFBlenderbotForConditionalGeneration
|
||||
:members:
|
||||
:members: call
|
||||
|
||||
84
docs/source/model_doc/blenderbot_small.rst
Normal file
84
docs/source/model_doc/blenderbot_small.rst
Normal file
@@ -0,0 +1,84 @@
|
||||
..
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
Blenderbot Small
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Note that :class:`~transformers.BlenderbotSmallModel` and
|
||||
:class:`~transformers.BlenderbotSmallForConditionalGeneration` are only used in combination with the checkpoint
|
||||
`facebook/blenderbot-90M <https://huggingface.co/facebook/blenderbot-90M>`__. Larger Blenderbot checkpoints should
|
||||
instead be used with :class:`~transformers.BlenderbotModel` and
|
||||
:class:`~transformers.BlenderbotForConditionalGeneration`
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The Blender chatbot model was proposed in `Recipes for building an open-domain chatbot
|
||||
<https://arxiv.org/pdf/2004.13637.pdf>`__ Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu,
|
||||
Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston on 30 Apr 2020.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
*Building open-domain chatbots is a challenging area for machine learning research. While prior work has shown that
|
||||
scaling neural models in the number of parameters and the size of the data they are trained on gives improved results,
|
||||
we show that other ingredients are important for a high-performing chatbot. Good conversation requires a number of
|
||||
skills that an expert conversationalist blends in a seamless way: providing engaging talking points and listening to
|
||||
their partners, and displaying knowledge, empathy and personality appropriately, while maintaining a consistent
|
||||
persona. We show that large scale models can learn these skills when given appropriate training data and choice of
|
||||
generation strategy. We build variants of these recipes with 90M, 2.7B and 9.4B parameter models, and make our models
|
||||
and code publicly available. Human evaluations show our best models are superior to existing approaches in multi-turn
|
||||
dialogue in terms of engagingness and humanness measurements. We then discuss the limitations of this work by analyzing
|
||||
failure cases of our models.*
|
||||
|
||||
The authors' code can be found `here <https://github.com/facebookresearch/ParlAI>`__ .
|
||||
|
||||
BlenderbotSmallConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BlenderbotSmallConfig
|
||||
:members:
|
||||
|
||||
|
||||
BlenderbotSmallTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BlenderbotSmallTokenizer
|
||||
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
|
||||
create_token_type_ids_from_sequences, save_vocabulary
|
||||
|
||||
|
||||
BlenderbotSmallModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BlenderbotSmallModel
|
||||
:members: forward
|
||||
|
||||
|
||||
BlenderbotSmallForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.BlenderbotSmallForConditionalGeneration
|
||||
:members: forward
|
||||
|
||||
|
||||
TFBlenderbotSmallModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFBlenderbotSmallModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFBlenderbotSmallForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFBlenderbotSmallForConditionalGeneration
|
||||
:members: call
|
||||
@@ -97,3 +97,8 @@ TFCTRLLMHeadModel
|
||||
.. autoclass:: transformers.TFCTRLLMHeadModel
|
||||
:members: call
|
||||
|
||||
TFCTRLForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFCTRLForSequenceClassification
|
||||
:members: call
|
||||
|
||||
71
docs/source/model_doc/herbert.rst
Normal file
71
docs/source/model_doc/herbert.rst
Normal file
@@ -0,0 +1,71 @@
|
||||
..
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
herBERT
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The herBERT model was proposed in `KLEJ: Comprehensive Benchmark for Polish Language Understanding
|
||||
<https://www.aclweb.org/anthology/2020.acl-main.111.pdf>`__ by Piotr Rybak, Robert Mroczkowski, Janusz Tracz, and
|
||||
Ireneusz Gawlik. It is a BERT-based Language Model trained on Polish Corpora using only MLM objective with dynamic
|
||||
masking of whole words.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*In recent years, a series of Transformer-based models unlocked major improvements in general natural language
|
||||
understanding (NLU) tasks. Such a fast pace of research would not be possible without general NLU benchmarks, which
|
||||
allow for a fair comparison of the proposed methods. However, such benchmarks are available only for a handful of
|
||||
languages. To alleviate this issue, we introduce a comprehensive multi-task benchmark for the Polish language
|
||||
understanding, accompanied by an online leaderboard. It consists of a diverse set of tasks, adopted from existing
|
||||
datasets for named entity recognition, question-answering, textual entailment, and others. We also introduce a new
|
||||
sentiment analysis task for the e-commerce domain, named Allegro Reviews (AR). To ensure a common evaluation scheme and
|
||||
promote models that generalize to different NLU tasks, the benchmark includes datasets from varying domains and
|
||||
applications. Additionally, we release HerBERT, a Transformer-based model trained specifically for the Polish language,
|
||||
which has the best average performance and obtains the best results for three out of nine tasks. Finally, we provide an
|
||||
extensive evaluation, including several standard baselines and recently proposed, multilingual Transformer-based
|
||||
models.*
|
||||
|
||||
Examples of use:
|
||||
|
||||
.. code-block::
|
||||
|
||||
from transformers import HerbertTokenizer, RobertaModel
|
||||
|
||||
tokenizer = HerbertTokenizer.from_pretrained("allegro/herbert-klej-cased-tokenizer-v1")
|
||||
model = RobertaModel.from_pretrained("allegro/herbert-klej-cased-v1")
|
||||
|
||||
encoded_input = tokenizer.encode("Kto ma lepszą sztukę, ma lepszy rząd – to jasne.", return_tensors='pt')
|
||||
outputs = model(encoded_input)
|
||||
|
||||
# HerBERT can also be loaded using AutoTokenizer and AutoModel:
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-klej-cased-tokenizer-v1")
|
||||
model = AutoModel.from_pretrained("allegro/herbert-klej-cased-v1")
|
||||
|
||||
|
||||
The original code can be found `here <https://github.com/allegro/HerBERT>`__.
|
||||
|
||||
HerbertTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.HerbertTokenizer
|
||||
:members:
|
||||
|
||||
HerbertTokenizerFast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.HerbertTokenizerFast
|
||||
:members:
|
||||
@@ -13,32 +13,72 @@
|
||||
LayoutLM
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. _Overview:
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The LayoutLM model was proposed in the paper `LayoutLM: Pre-training of Text and Layout for Document Image
|
||||
Understanding <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, and
|
||||
Ming Zhou. It's a simple but effective pretraining method of text and layout for document image understanding and
|
||||
information extraction tasks, such as form understanding and receipt understanding.
|
||||
information extraction tasks, such as form understanding and receipt understanding. It obtains state-of-the-art results
|
||||
on several downstream tasks:
|
||||
|
||||
- form understanding: the `FUNSD <https://guillaumejaume.github.io/FUNSD/>`__ dataset (a collection of 199 annotated
|
||||
forms comprising more than 30,000 words).
|
||||
- receipt understanding: the `SROIE <https://rrc.cvc.uab.es/?ch=13>`__ dataset (a collection of 626 receipts for
|
||||
training and 347 receipts for testing).
|
||||
- document image classification: the `RVL-CDIP <https://www.cs.cmu.edu/~aharley/rvl-cdip/>`__ dataset (a collection of
|
||||
400,000 images belonging to one of 16 classes).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Pre-training techniques have been verified successfully in a variety of NLP tasks in recent years. Despite the
|
||||
widespread use of pretraining models for NLP applications, they almost exclusively focus on text-level manipulation,
|
||||
while neglecting layout and style information that is vital for document image understanding. In this paper, we propose
|
||||
the \textbf{LayoutLM} to jointly model interactions between text and layout information across scanned document images,
|
||||
which is beneficial for a great number of real-world document image understanding tasks such as information extraction
|
||||
from scanned documents. Furthermore, we also leverage image features to incorporate words' visual information into
|
||||
LayoutLM. To the best of our knowledge, this is the first time that text and layout are jointly learned in a single
|
||||
framework for document-level pretraining. It achieves new state-of-the-art results in several downstream tasks,
|
||||
including form understanding (from 70.72 to 79.27), receipt understanding (from 94.02 to 95.24) and document image
|
||||
classification (from 93.07 to 94.42).*
|
||||
the LayoutLM to jointly model interactions between text and layout information across scanned document images, which is
|
||||
beneficial for a great number of real-world document image understanding tasks such as information extraction from
|
||||
scanned documents. Furthermore, we also leverage image features to incorporate words' visual information into LayoutLM.
|
||||
To the best of our knowledge, this is the first time that text and layout are jointly learned in a single framework for
|
||||
document-level pretraining. It achieves new state-of-the-art results in several downstream tasks, including form
|
||||
understanding (from 70.72 to 79.27), receipt understanding (from 94.02 to 95.24) and document image classification
|
||||
(from 93.07 to 94.42).*
|
||||
|
||||
Tips:
|
||||
|
||||
- LayoutLM has an extra input called :obj:`bbox`, which is the bounding boxes of the input tokens.
|
||||
- The :obj:`bbox` requires the data that on 0-1000 scale, which means you should normalize the bounding box before
|
||||
passing them into model.
|
||||
- In addition to `input_ids`, :meth:`~transformer.LayoutLMModel.forward` also expects the input :obj:`bbox`, which are
|
||||
the bounding boxes (i.e. 2D-positions) of the input tokens. These can be obtained using an external OCR engine such
|
||||
as Google's `Tesseract <https://github.com/tesseract-ocr/tesseract>`__ (there's a `Python wrapper
|
||||
<https://pypi.org/project/pytesseract/>`__ available). Each bounding box should be in (x0, y0, x1, y1) format, where
|
||||
(x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, y1) represents the
|
||||
position of the lower right corner. Note that one first needs to normalize the bounding boxes to be on a 0-1000
|
||||
scale. To normalize, you can use the following function:
|
||||
|
||||
.. code-block::
|
||||
|
||||
def normalize_bbox(bbox, width, height):
|
||||
return [
|
||||
int(1000 * (bbox[0] / width)),
|
||||
int(1000 * (bbox[1] / height)),
|
||||
int(1000 * (bbox[2] / width)),
|
||||
int(1000 * (bbox[3] / height)),
|
||||
]
|
||||
|
||||
Here, :obj:`width` and :obj:`height` correspond to the width and height of the original document in which the token
|
||||
occurs. Those can be obtained using the Python Image Library (PIL) library for example, as follows:
|
||||
|
||||
.. code-block::
|
||||
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open("name_of_your_document - can be a png file, pdf, etc.")
|
||||
|
||||
width, height = image.size
|
||||
|
||||
- For a demo which shows how to fine-tune :class:`LayoutLMForTokenClassification` on the `FUNSD dataset
|
||||
<https://guillaumejaume.github.io/FUNSD/>`__ (a collection of annotated forms), see `this notebook
|
||||
<https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LayoutLM/Fine_tuning_LayoutLMForTokenClassification_on_FUNSD.ipynb>`__.
|
||||
It includes an inference part, which shows how to use Google's Tesseract on a new document.
|
||||
|
||||
The original code can be found `here <https://github.com/microsoft/unilm/tree/master/layoutlm>`_.
|
||||
|
||||
@@ -78,6 +118,13 @@ LayoutLMForMaskedLM
|
||||
:members:
|
||||
|
||||
|
||||
LayoutLMForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.LayoutLMForSequenceClassification
|
||||
:members:
|
||||
|
||||
|
||||
LayoutLMForTokenClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
149
docs/source/model_doc/led.rst
Normal file
149
docs/source/model_doc/led.rst
Normal file
@@ -0,0 +1,149 @@
|
||||
..
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
LED
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The LED model was proposed in `Longformer: The Long-Document Transformer <https://arxiv.org/abs/2004.05150>`__ by Iz
|
||||
Beltagy, Matthew E. Peters, Arman Cohan.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*Transformer-based models are unable to process long sequences due to their self-attention operation, which scales
|
||||
quadratically with the sequence length. To address this limitation, we introduce the Longformer with an attention
|
||||
mechanism that scales linearly with sequence length, making it easy to process documents of thousands of tokens or
|
||||
longer. Longformer's attention mechanism is a drop-in replacement for the standard self-attention and combines a local
|
||||
windowed attention with a task motivated global attention. Following prior work on long-sequence transformers, we
|
||||
evaluate Longformer on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In
|
||||
contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream tasks. Our
|
||||
pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new state-of-the-art results on
|
||||
WikiHop and TriviaQA. We finally introduce the Longformer-Encoder-Decoder (LED), a Longformer variant for supporting
|
||||
long document generative sequence-to-sequence tasks, and demonstrate its effectiveness on the arXiv summarization
|
||||
dataset.*
|
||||
|
||||
Tips:
|
||||
|
||||
- :class:`~transformers.LEDForConditionalGeneration` is an extension of
|
||||
:class:`~transformers.BartForConditionalGeneration` exchanging the traditional *self-attention* layer with
|
||||
*Longformer*'s *chunked self-attention* layer. :class:`~transformers.LEDTokenizer` is an alias of
|
||||
:class:`~transformers.BartTokenizer`.
|
||||
- LED works very well on long-range *sequence-to-sequence* tasks where the ``input_ids`` largely exceed a length of
|
||||
1024 tokens.
|
||||
- LED pads the ``input_ids`` to be a multiple of ``config.attention_window`` if required. Therefore a small speed-up is
|
||||
gained, when :class:`~transformers.LEDTokenizer` is used with the ``pad_to_multiple_of`` argument.
|
||||
- LED makes use of *global attention* by means of the ``global_attention_mask`` (see
|
||||
:class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first
|
||||
``<s>`` token. For question answering, it is advised to put *global attention* on all tokens of the question.
|
||||
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by setting
|
||||
``config.gradient_checkpointing = True``.
|
||||
- A notebook showing how to evaluate LED, can be accessed `here
|
||||
<https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing>`__.
|
||||
- A notebook showing how to fine-tune LED, can be accessed `here
|
||||
<https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing>`__.
|
||||
|
||||
|
||||
LEDConfig
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.LEDConfig
|
||||
:members:
|
||||
|
||||
|
||||
LEDTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.LEDTokenizer
|
||||
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
|
||||
create_token_type_ids_from_sequences, save_vocabulary
|
||||
|
||||
|
||||
LEDTokenizerFast
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.LEDTokenizerFast
|
||||
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
|
||||
create_token_type_ids_from_sequences, save_vocabulary
|
||||
|
||||
|
||||
LED specific outputs
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.models.led.modeling_led.LEDEncoderBaseModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqLMOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqSequenceClassifierOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.led.modeling_led.LEDSeq2SeqQuestionAnsweringModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.led.modeling_tf_led.TFLEDEncoderBaseModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.led.modeling_tf_led.TFLEDSeq2SeqModelOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.models.led.modeling_tf_led.TFLEDSeq2SeqLMOutput
|
||||
:members:
|
||||
|
||||
|
||||
|
||||
|
||||
LEDModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.LEDModel
|
||||
:members: forward
|
||||
|
||||
|
||||
LEDForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.LEDForConditionalGeneration
|
||||
:members: forward
|
||||
|
||||
|
||||
LEDForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.LEDForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
LEDForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.LEDForQuestionAnswering
|
||||
:members: forward
|
||||
|
||||
|
||||
TFLEDModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFLEDModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFLEDForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFLEDForConditionalGeneration
|
||||
:members: call
|
||||
@@ -33,7 +33,6 @@ Implementation Notes
|
||||
- The modeling code is the same as :class:`~transformers.BartForConditionalGeneration` with a few minor modifications:
|
||||
|
||||
- static (sinusoid) positional embeddings (:obj:`MarianConfig.static_position_embeddings=True`)
|
||||
- a new final_logits_bias (:obj:`MarianConfig.add_bias_logits=True`)
|
||||
- no layernorm_embedding (:obj:`MarianConfig.normalize_embedding=False`)
|
||||
- the model starts generating with :obj:`pad_token_id` (which has 0 as a token_embedding) as the prefix (Bart uses
|
||||
:obj:`<s/>`),
|
||||
@@ -56,12 +55,10 @@ Examples
|
||||
|
||||
- Since Marian models are smaller than many other translation models available in the library, they can be useful for
|
||||
fine-tuning experiments and integration tests.
|
||||
- `Fine-tune on TPU
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/seq2seq/builtin_trainer/train_distil_marian_enro_tpu.sh>`__
|
||||
- `Fine-tune on GPU
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/seq2seq/builtin_trainer/train_distil_marian_enro.sh>`__
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/research_projects/seq2seq-distillation/train_distil_marian_enro_teacher.sh>`__
|
||||
- `Fine-tune on GPU with pytorch-lightning
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/seq2seq/distil_marian_no_teacher.sh>`__
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/research_projects/seq2seq-distillation/train_distil_marian_no_teacher.sh>`__
|
||||
|
||||
Multilingual Models
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -182,13 +179,29 @@ MarianTokenizer
|
||||
:members: prepare_seq2seq_batch
|
||||
|
||||
|
||||
MarianModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MarianModel
|
||||
:members: forward
|
||||
|
||||
|
||||
MarianMTModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MarianMTModel
|
||||
:members: forward
|
||||
|
||||
|
||||
TFMarianModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFMarianModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFMarianMTModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFMarianMTModel
|
||||
:members: call
|
||||
|
||||
@@ -35,7 +35,7 @@ Examples
|
||||
_______________________________________________________________________________________________________________________
|
||||
|
||||
- Examples and scripts for fine-tuning mBART and other models for sequence to sequence tasks can be found in
|
||||
`examples/seq2seq/ <https://github.com/huggingface/transformers/blob/master/examples/seq2seq/README.md>`__.
|
||||
:prefix_link:`examples/seq2seq/ <examples/seq2seq/README.md>`.
|
||||
- Given the large embeddings table, mBART consumes a large amount of GPU RAM, especially for fine-tuning.
|
||||
:class:`MarianMTModel` is usually a better choice for bilingual machine translation.
|
||||
|
||||
@@ -97,6 +97,13 @@ MBartTokenizerFast
|
||||
:members:
|
||||
|
||||
|
||||
MBartModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MBartModel
|
||||
:members:
|
||||
|
||||
|
||||
MBartForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -104,8 +111,28 @@ MBartForConditionalGeneration
|
||||
:members:
|
||||
|
||||
|
||||
MBartForQuestionAnswering
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MBartForQuestionAnswering
|
||||
:members:
|
||||
|
||||
|
||||
MBartForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.MBartForSequenceClassification
|
||||
|
||||
|
||||
TFMBartModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFMBartModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFMBartForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFMBartForConditionalGeneration
|
||||
:members:
|
||||
:members: call
|
||||
|
||||
@@ -51,9 +51,8 @@ All the `checkpoints <https://huggingface.co/models?search=pegasus>`__ are fine-
|
||||
Examples
|
||||
_______________________________________________________________________________________________________________________
|
||||
|
||||
- `Script <https://github.com/huggingface/transformers/blob/master/examples/seq2seq/finetune_pegasus_xsum.sh>`__ to
|
||||
fine-tune pegasus on the XSUM dataset. Data download instructions at `examples/seq2seq/
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/seq2seq/README.md>`__.
|
||||
- :prefix_link:`Script <examples/seq2seq/finetune_pegasus_xsum.sh>` to fine-tune pegasus on the XSUM dataset. Data
|
||||
download instructions at :prefix_link:`examples/seq2seq/ <examples/seq2seq/README.md>`.
|
||||
- FP16 is not supported (help/ideas on this appreciated!).
|
||||
- The adafactor optimizer is recommended for pegasus fine-tuning.
|
||||
|
||||
@@ -66,7 +65,6 @@ Implementation Notes
|
||||
- Some key configuration differences:
|
||||
|
||||
- static, sinusoidal position embeddings
|
||||
- no :obj:`layernorm_embedding` (:obj:`PegasusConfig.normalize_embedding=False`)
|
||||
- the model starts generating with pad_token_id (which has 0 token_embedding) as the prefix.
|
||||
- more beams are used (:obj:`num_beams=8`)
|
||||
- All pretrained pegasus checkpoints are the same besides three attributes: :obj:`tokenizer.model_max_length` (maximum
|
||||
@@ -119,13 +117,29 @@ PegasusTokenizerFast
|
||||
:members:
|
||||
|
||||
|
||||
PegasusModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.PegasusModel
|
||||
:members: forward
|
||||
|
||||
|
||||
PegasusForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.PegasusForConditionalGeneration
|
||||
:members: forward
|
||||
|
||||
|
||||
TFPegasusModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFPegasusModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFPegasusForConditionalGeneration
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFPegasusForConditionalGeneration
|
||||
:members: call
|
||||
|
||||
59
docs/source/model_doc/phobert.rst
Normal file
59
docs/source/model_doc/phobert.rst
Normal file
@@ -0,0 +1,59 @@
|
||||
..
|
||||
Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
PhoBERT
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Overview
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The PhoBERT model was proposed in `PhoBERT: Pre-trained language models for Vietnamese
|
||||
<https://www.aclweb.org/anthology/2020.findings-emnlp.92.pdf>`__ by Dat Quoc Nguyen, Anh Tuan Nguyen.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
*We present PhoBERT with two versions, PhoBERT-base and PhoBERT-large, the first public large-scale monolingual
|
||||
language models pre-trained for Vietnamese. Experimental results show that PhoBERT consistently outperforms the recent
|
||||
best pre-trained multilingual model XLM-R (Conneau et al., 2020) and improves the state-of-the-art in multiple
|
||||
Vietnamese-specific NLP tasks including Part-of-speech tagging, Dependency parsing, Named-entity recognition and
|
||||
Natural language inference.*
|
||||
|
||||
Example of use:
|
||||
|
||||
.. code-block::
|
||||
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
phobert = AutoModel.from_pretrained("vinai/phobert-base")
|
||||
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
|
||||
|
||||
# INPUT TEXT MUST BE ALREADY WORD-SEGMENTED!
|
||||
line = "Tôi là sinh_viên trường đại_học Công_nghệ ."
|
||||
|
||||
input_ids = torch.tensor([tokenizer.encode(line)])
|
||||
|
||||
with torch.no_grad():
|
||||
features = phobert(input_ids) # Models outputs are now tuples
|
||||
|
||||
## With TensorFlow 2.0+:
|
||||
# from transformers import TFAutoModel
|
||||
# phobert = TFAutoModel.from_pretrained("vinai/phobert-base")
|
||||
|
||||
|
||||
The original code can be found `here <https://github.com/VinAIResearch/PhoBERT>`__.
|
||||
|
||||
PhobertTokenizer
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.PhobertTokenizer
|
||||
:members:
|
||||
@@ -44,9 +44,9 @@ Tips:
|
||||
|
||||
For more information about which prefix to use, it is easiest to look into Appendix D of the `paper
|
||||
<https://arxiv.org/pdf/1910.10683.pdf>`__. - For sequence-to-sequence generation, it is recommended to use
|
||||
:obj:`T5ForConditionalGeneration.generate()``. This method takes care of feeding the encoded input via
|
||||
cross-attention layers to the decoder and auto-regressively generates the decoder output. - T5 uses relative scalar
|
||||
embeddings. Encoder input padding can be done on the left and on the right.
|
||||
:obj:`T5ForConditionalGeneration.generate()`. This method takes care of feeding the encoded input via cross-attention
|
||||
layers to the decoder and auto-regressively generates the decoder output. - T5 uses relative scalar embeddings.
|
||||
Encoder input padding can be done on the left and on the right.
|
||||
|
||||
The original code can be found `here <https://github.com/google-research/text-to-text-transfer-transformer>`__.
|
||||
|
||||
@@ -55,7 +55,7 @@ Training
|
||||
|
||||
T5 is an encoder-decoder model and converts all NLP problems into a text-to-text format. It is trained using teacher
|
||||
forcing. This means that for training we always need an input sequence and a target sequence. The input sequence is fed
|
||||
to the model using :obj:`input_ids``. The target sequence is shifted to the right, i.e., prepended by a start-sequence
|
||||
to the model using :obj:`input_ids`. The target sequence is shifted to the right, i.e., prepended by a start-sequence
|
||||
token and fed to the decoder using the :obj:`decoder_input_ids`. In teacher-forcing style, the target sequence is then
|
||||
appended by the EOS token and corresponds to the :obj:`labels`. The PAD token is hereby used as the start-sequence
|
||||
token. T5 can be trained / fine-tuned both in a supervised and unsupervised fashion.
|
||||
|
||||
@@ -265,7 +265,7 @@ conversational**. In case your dataset involves conversational questions (such a
|
||||
together the ``queries``, ``answer_coordinates`` and ``answer_text`` per table (in the order of their ``position``
|
||||
index) and batch encode each table with its questions. This will make sure that the ``prev_labels`` token types (see
|
||||
docs of :class:`~transformers.TapasTokenizer`) are set correctly. See `this notebook
|
||||
<https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb>`__
|
||||
<https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb>`__
|
||||
for more info.
|
||||
|
||||
**STEP 4: Train (fine-tune) TapasForQuestionAnswering**
|
||||
@@ -346,7 +346,7 @@ of that:
|
||||
... inputs,
|
||||
... outputs.logits.detach(),
|
||||
... outputs.logits_aggregation.detach()
|
||||
...)
|
||||
... )
|
||||
|
||||
>>> # let's print out the results:
|
||||
>>> id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3:"COUNT"}
|
||||
@@ -382,7 +382,7 @@ of that:
|
||||
In case of a conversational set-up, then each table-question pair must be provided **sequentially** to the model, such
|
||||
that the ``prev_labels`` token types can be overwritten by the predicted ``labels`` of the previous table-question
|
||||
pair. Again, more info can be found in `this notebook
|
||||
<https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb>`__.
|
||||
<https://github.com/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb>`__.
|
||||
|
||||
|
||||
Tapas specific outputs
|
||||
|
||||
@@ -87,12 +87,14 @@ TransfoXLLMHeadModel
|
||||
.. autoclass:: transformers.TransfoXLLMHeadModel
|
||||
:members: forward
|
||||
|
||||
|
||||
TransfoXLForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TransfoXLForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
TFTransfoXLModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -107,6 +109,13 @@ TFTransfoXLLMHeadModel
|
||||
:members: call
|
||||
|
||||
|
||||
TFTransfoXLForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.TFTransfoXLForSequenceClassification
|
||||
:members: call
|
||||
|
||||
|
||||
Internal Layers
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -78,6 +78,12 @@ Once you are logged in with your model hub credentials, you can start building y
|
||||
|
||||
transformers-cli repo create your-model-name
|
||||
|
||||
If you want to create a repo under a specific organization, you should add a `--organization` flag:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
transformers-cli repo create your-model-name --organization your-org-name
|
||||
|
||||
This creates a repo on the model hub, which can be cloned.
|
||||
|
||||
.. code-block:: bash
|
||||
@@ -105,6 +111,9 @@ The only learning curve you might have compared to regular git is the one for gi
|
||||
`git-lfs.github.com <https://git-lfs.github.com/>`__ is decent, but we'll work on a tutorial with some tips and tricks
|
||||
in the coming weeks!
|
||||
|
||||
Additionally, if you want to change multiple repos at once, the `change_config.py script
|
||||
<https://github.com/huggingface/efficient_scripts/blob/main/change_config.py>`__ can probably save you some time.
|
||||
|
||||
Make your model work on all frameworks
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ Summary of the models
|
||||
This is a summary of the models available in 🤗 Transformers. It assumes you’re familiar with the original `transformer
|
||||
model <https://arxiv.org/abs/1706.03762>`_. For a gentle introduction check the `annotated transformer
|
||||
<http://nlp.seas.harvard.edu/2018/04/03/attention.html>`_. Here we focus on the high-level differences between the
|
||||
models. You can check them more in detail in their respective documentation. Also checkout the :doc:`pretrained model
|
||||
models. You can check them more in detail in their respective documentation. Also check out the :doc:`pretrained model
|
||||
page </pretrained_models>` to see the checkpoints available for each type of model and all `the community models
|
||||
<https://huggingface.co/models>`_.
|
||||
|
||||
@@ -30,7 +30,7 @@ Each one of the models in the library falls into one of the following categories
|
||||
|
||||
Autoregressive models are pretrained on the classic language modeling task: guess the next token having read all the
|
||||
previous ones. They correspond to the decoder of the original transformer model, and a mask is used on top of the full
|
||||
sentence so that the attention heads can only see what was before in the next, and not what’s after. Although those
|
||||
sentence so that the attention heads can only see what was before in the text, and not what’s after. Although those
|
||||
models can be fine-tuned and achieve great results on many tasks, the most natural application is text generation. A
|
||||
typical example of such models is GPT.
|
||||
|
||||
@@ -512,8 +512,8 @@ BART
|
||||
<https://arxiv.org/abs/1910.13461>`_, Mike Lewis et al.
|
||||
|
||||
Sequence-to-sequence model with an encoder and a decoder. Encoder is fed a corrupted version of the tokens, decoder is
|
||||
fed the original tokens (but has a mask to hide the future words like a regular transformers decoder). For the encoder
|
||||
, on the pretraining tasks, a composition of the following transformations are applied:
|
||||
fed the original tokens (but has a mask to hide the future words like a regular transformers decoder). A composition of
|
||||
the following transformations are applied on the pretraining tasks for the encoder:
|
||||
|
||||
* mask random tokens (like in BERT)
|
||||
* delete random tokens
|
||||
|
||||
@@ -90,9 +90,8 @@ You can then feed it all as input to your model:
|
||||
>>> outputs = model(input_ids, langs=langs)
|
||||
|
||||
|
||||
The example `run_generation.py
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/text-generation/run_generation.py>`__ can generate
|
||||
text using the CLM checkpoints from XLM, using the language embeddings.
|
||||
The example :prefix_link:`run_generation.py <examples/text-generation/run_generation.py>` can generate text using the
|
||||
CLM checkpoints from XLM, using the language embeddings.
|
||||
|
||||
XLM without Language Embeddings
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
@@ -78,7 +78,7 @@ The library is built around three types of classes for each model:
|
||||
All these classes can be instantiated from pretrained instances and saved locally using two methods:
|
||||
|
||||
- :obj:`from_pretrained()` lets you instantiate a model/configuration/tokenizer from a pretrained version either
|
||||
provided by the library itself (the supported models are provided in the list :doc:`here <pretrained_models>` or
|
||||
provided by the library itself (the supported models are provided in the list :doc:`here <pretrained_models>`) or
|
||||
stored locally (or on a server) by the user,
|
||||
- :obj:`save_pretrained()` lets you save a model/configuration/tokenizer locally so that it can be reloaded using
|
||||
:obj:`from_pretrained()`.
|
||||
|
||||
@@ -10,17 +10,17 @@
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
reprocessing data
|
||||
Preprocessing data
|
||||
=======================================================================================================================
|
||||
|
||||
In this tutorial, we'll explore how to preprocess your data using 🤗 Transformers. The main tool for this is what we
|
||||
call a :doc:`tokenizer <main_classes/tokenizer>`. You can build one using the tokenizer class associated to the model
|
||||
you would like to use, or directly with the :class:`~transformers.AutoTokenizer` class.
|
||||
|
||||
As we saw in the :doc:`quicktour </quicktour>`, the tokenizer will first split a given text in words (or part of words,
|
||||
punctuation symbols, etc.) usually called `tokens`. Then it will convert those `tokens` into numbers, to be able to
|
||||
build a tensor out of them and feed them to the model. It will also add any additional inputs the model might expect to
|
||||
work properly.
|
||||
As we saw in the :doc:`quick tour </quicktour>`, the tokenizer will first split a given text in words (or part of
|
||||
words, punctuation symbols, etc.) usually called `tokens`. Then it will convert those `tokens` into numbers, to be able
|
||||
to build a tensor out of them and feed them to the model. It will also add any additional inputs the model might expect
|
||||
to work properly.
|
||||
|
||||
.. note::
|
||||
|
||||
@@ -131,7 +131,7 @@ ones it should not (because they represent padding in this case).
|
||||
|
||||
|
||||
Note that if your model does not have a maximum length associated to it, the command above will throw a warning. You
|
||||
can safely ignore it. You can also pass ``verbose=False`` to stop the tokenizer to throw those kinds of warnings.
|
||||
can safely ignore it. You can also pass ``verbose=False`` to stop the tokenizer from throwing those kinds of warnings.
|
||||
|
||||
.. _sentence-pairs:
|
||||
|
||||
@@ -216,7 +216,6 @@ Everything you always wanted to know about padding and truncation
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
We have seen the commands that will work for most cases (pad your batch to the length of the maximum sentence and
|
||||
|
||||
truncate to the maximum length the mode can accept). However, the API supports more strategies if you need them. The
|
||||
three arguments you need to know for this are :obj:`padding`, :obj:`truncation` and :obj:`max_length`.
|
||||
|
||||
|
||||
@@ -13,10 +13,9 @@
|
||||
Pretrained models
|
||||
=======================================================================================================================
|
||||
|
||||
Here is the full list of the currently provided pretrained models together with a short presentation of each model.
|
||||
Here is a partial list of some of the available pretrained models together with a short presentation of each model.
|
||||
|
||||
For a list that includes all community-uploaded models, refer to `https://huggingface.co/models
|
||||
<https://huggingface.co/models>`__.
|
||||
For the full list, refer to `https://huggingface.co/models <https://huggingface.co/models>`__.
|
||||
|
||||
+--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
|
||||
| Architecture | Model id | Details of the model |
|
||||
|
||||
@@ -158,7 +158,7 @@ Using the tokenizer
|
||||
|
||||
We mentioned the tokenizer is responsible for the preprocessing of your texts. First, it will split a given text in
|
||||
words (or part of words, punctuation symbols, etc.) usually called `tokens`. There are multiple rules that can govern
|
||||
that process (you can learn more about them in the :doc:`tokenizer summary <tokenizer_summary>`, which is why we need
|
||||
that process (you can learn more about them in the :doc:`tokenizer summary <tokenizer_summary>`), which is why we need
|
||||
to instantiate the tokenizer using the name of the model, to make sure we use the same rules as when the model was
|
||||
pretrained.
|
||||
|
||||
|
||||
@@ -327,7 +327,7 @@ Masked Language Modeling
|
||||
Masked language modeling is the task of masking tokens in a sequence with a masking token, and prompting the model to
|
||||
fill that mask with an appropriate token. This allows the model to attend to both the right context (tokens on the
|
||||
right of the mask) and the left context (tokens on the left of the mask). Such a training creates a strong basis for
|
||||
downstream tasks, requiring bi-directional context such as SQuAD (question answering, see `Lewis, Lui, Goyal et al.
|
||||
downstream tasks requiring bi-directional context, such as SQuAD (question answering, see `Lewis, Lui, Goyal et al.
|
||||
<https://arxiv.org/abs/1910.13461>`__, part 4.2).
|
||||
|
||||
Here is an example of using pipelines to replace a mask from a sequence:
|
||||
@@ -657,7 +657,7 @@ Here are the expected results:
|
||||
{'word': 'Bridge', 'score': 0.990249514579773, 'entity': 'I-LOC'}
|
||||
]
|
||||
|
||||
Note, how the tokens of the sequence "Hugging Face" have been identified as an organisation, and "New York City",
|
||||
Note how the tokens of the sequence "Hugging Face" have been identified as an organisation, and "New York City",
|
||||
"DUMBO" and "Manhattan Bridge" have been identified as locations.
|
||||
|
||||
Here is an example of doing named entity recognition, using a model and a tokenizer. The process is the following:
|
||||
@@ -750,8 +750,7 @@ Summarization is the task of summarizing a document or an article into a shorter
|
||||
|
||||
An example of a summarization dataset is the CNN / Daily Mail dataset, which consists of long news articles and was
|
||||
created for the task of summarization. If you would like to fine-tune a model on a summarization task, various
|
||||
approaches are described in this `document
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/seq2seq/README.md>`__.
|
||||
approaches are described in this :prefix_link:`document <examples/seq2seq/README.md>`.
|
||||
|
||||
Here is an example of using the pipelines to do summarization. It leverages a Bart model that was fine-tuned on the CNN
|
||||
/ Daily Mail data set.
|
||||
@@ -829,8 +828,7 @@ Translation is the task of translating a text from one language to another.
|
||||
|
||||
An example of a translation dataset is the WMT English to German dataset, which has sentences in English as the input
|
||||
data and the corresponding sentences in German as the target data. If you would like to fine-tune a model on a
|
||||
translation task, various approaches are described in this `document
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/seq2seq/README.md>`__.
|
||||
translation task, various approaches are described in this :prefix_link:`document <examples/seq2seq/README.md>`.
|
||||
|
||||
Here is an example of using the pipelines to do translation. It leverages a T5 model that was only pre-trained on a
|
||||
multi-task mixture dataset (including WMT), yet, yielding impressive translation results.
|
||||
|
||||
@@ -25,25 +25,22 @@ How transformers are tested
|
||||
-----------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
1. Once a PR is submitted it gets tested with 9 CircleCi jobs. Every new commit to that PR gets retested. These jobs
|
||||
are defined in this `config file <https://github.com/huggingface/transformers/blob/master/.circleci/config.yml>`__,
|
||||
so that if needed you can reproduce the same environment on your machine.
|
||||
are defined in this :prefix_link:`config file <.circleci/config.yml>`, so that if needed you can reproduce the same
|
||||
environment on your machine.
|
||||
|
||||
These CI jobs don't run ``@slow`` tests.
|
||||
|
||||
2. There are 3 jobs run by `github actions <https://github.com/huggingface/transformers/actions>`__:
|
||||
|
||||
* `torch hub integration
|
||||
<https://github.com/huggingface/transformers/blob/master/.github/workflows/github-torch-hub.yml>`__: checks
|
||||
whether torch hub integration works.
|
||||
* :prefix_link:`torch hub integration <.github/workflows/github-torch-hub.yml>`: checks whether torch hub
|
||||
integration works.
|
||||
|
||||
* `self-hosted (push) <https://github.com/huggingface/transformers/blob/master/.github/workflows/self-push.yml>`__:
|
||||
runs fast tests on GPU only on commits on ``master``. It only runs if a commit on ``master`` has updated the code
|
||||
in one of the following folders: ``src``, ``tests``, ``.github`` (to prevent running on added model cards,
|
||||
notebooks, etc.)
|
||||
* :prefix_link:`self-hosted (push) <.github/workflows/self-push.yml>`: runs fast tests on GPU only on commits on
|
||||
``master``. It only runs if a commit on ``master`` has updated the code in one of the following folders: ``src``,
|
||||
``tests``, ``.github`` (to prevent running on added model cards, notebooks, etc.)
|
||||
|
||||
* `self-hosted runner
|
||||
<https://github.com/huggingface/transformers/blob/master/.github/workflows/self-scheduled.yml>`__: runs normal and
|
||||
slow tests on GPU in ``tests`` and ``examples``:
|
||||
* :prefix_link:`self-hosted runner <.github/workflows/self-scheduled.yml>`: runs normal and slow tests on GPU in
|
||||
``tests`` and ``examples``:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
@@ -492,12 +489,9 @@ spawns a normal process that then spawns off multiple workers and manages the IO
|
||||
|
||||
This is still under development but you can study 2 different tests that perform this successfully:
|
||||
|
||||
* `test_seq2seq_examples_multi_gpu.py
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/seq2seq/test_seq2seq_examples_multi_gpu.py>`__ - a
|
||||
* :prefix_link:`test_seq2seq_examples_multi_gpu.py <examples/seq2seq/test_seq2seq_examples_multi_gpu.py>` - a
|
||||
``pytorch-lightning``-running test (had to use PL's ``ddp`` spawning method which is the default)
|
||||
* `test_finetune_trainer.py
|
||||
<https://github.com/huggingface/transformers/blob/master/examples/seq2seq/test_finetune_trainer.py>`__ - a normal
|
||||
(non-PL) test
|
||||
* :prefix_link:`test_finetune_trainer.py <examples/seq2seq/test_finetune_trainer.py>` - a normal (non-PL) test
|
||||
|
||||
To jump right into the execution point, search for the ``execute_subprocess_async`` function in those tests.
|
||||
|
||||
@@ -940,10 +934,9 @@ slow models to do qualitative testing. To see the use of these simply look for *
|
||||
|
||||
grep tiny tests examples
|
||||
|
||||
Here is a an example of a `script
|
||||
<https://github.com/huggingface/transformers/blob/master/scripts/fsmt/fsmt-make-tiny-model.py>`__ that created the tiny
|
||||
model `stas/tiny-wmt19-en-de <https://huggingface.co/stas/tiny-wmt19-en-de>`__. You can easily adjust it to your
|
||||
specific model's architecture.
|
||||
Here is a an example of a :prefix_link:`script <scripts/fsmt/fsmt-make-tiny-model.py>` that created the tiny model
|
||||
`stas/tiny-wmt19-en-de <https://huggingface.co/stas/tiny-wmt19-en-de>`__. You can easily adjust it to your specific
|
||||
model's architecture.
|
||||
|
||||
It's easy to measure the run-time incorrectly if for example there is an overheard of downloading a huge model, but if
|
||||
you test it locally the downloaded files would be cached and thus the download time not measured. Hence check the
|
||||
|
||||
@@ -18,7 +18,7 @@ On this page, we will have a closer look at tokenization. As we saw in :doc:`the
|
||||
look-up table. Converting words or subwords to ids is straightforward, so in this summary, we will focus on splitting a
|
||||
text into words or subwords (i.e. tokenizing a text). More specifically, we will look at the three main types of
|
||||
tokenizers used in 🤗 Transformers: :ref:`Byte-Pair Encoding (BPE) <byte-pair-encoding>`, :ref:`WordPiece <wordpiece>`,
|
||||
and :ref:`SentencePiece <sentencepiece>`, and show exemplary which tokenizer type is used by which model.
|
||||
and :ref:`SentencePiece <sentencepiece>`, and show examples of which tokenizer type is used by which model.
|
||||
|
||||
Note that on each model page, you can look at the documentation of the associated tokenizer to know which tokenizer
|
||||
type was used by the pretrained model. For instance, if we look at :class:`~transformers.BertTokenizer`, we can see
|
||||
@@ -72,7 +72,7 @@ greater than 50,000, especially if they are pretrained only on a single language
|
||||
So if simple space and punctuation tokenization is unsatisfactory, why not simply tokenize on characters? While
|
||||
character tokenization is very simple and would greatly reduce memory and time complexity it makes it much harder for
|
||||
the model to learn meaningful input representations. *E.g.* learning a meaningful context-independent representation
|
||||
for the letter ``"t"`` is much harder as learning a context-independent representation for the word ``"today"``.
|
||||
for the letter ``"t"`` is much harder than learning a context-independent representation for the word ``"today"``.
|
||||
Therefore, character tokenization is often accompanied by a loss of performance. So to get the best of both worlds,
|
||||
transformers models use a hybrid between word-level and character-level tokenization called **subword** tokenization.
|
||||
|
||||
@@ -202,10 +202,10 @@ WordPiece
|
||||
|
||||
WordPiece is the subword tokenization algorithm used for :doc:`BERT <model_doc/bert>`, :doc:`DistilBERT
|
||||
<model_doc/distilbert>`, and :doc:`Electra <model_doc/electra>`. The algorithm was outlined in `Japanese and Korean
|
||||
Voice Seach (Schuster et al., 2012)
|
||||
Voice Search (Schuster et al., 2012)
|
||||
<https://static.googleusercontent.com/media/research.google.com/ja//pubs/archive/37842.pdf>`__ and is very similar to
|
||||
BPE. WordPiece first initializes the vocabulary to include every character present in the training data and
|
||||
progressively learn a given number of merge rules. In contrast to BPE, WordPiece does not choose the most frequent
|
||||
progressively learns a given number of merge rules. In contrast to BPE, WordPiece does not choose the most frequent
|
||||
symbol pair, but the one that maximizes the likelihood of the training data once added to the vocabulary.
|
||||
|
||||
So what does this mean exactly? Referring to the previous example, maximizing the likelihood of the training data is
|
||||
|
||||
@@ -14,7 +14,7 @@ Training and fine-tuning
|
||||
=======================================================================================================================
|
||||
|
||||
Model classes in 🤗 Transformers are designed to be compatible with native PyTorch and TensorFlow 2 and can be used
|
||||
seemlessly with either. In this quickstart, we will show how to fine-tune (or train from scratch) a model using the
|
||||
seamlessly with either. In this quickstart, we will show how to fine-tune (or train from scratch) a model using the
|
||||
standard training tools available in either framework. We will also show how to use our included
|
||||
:func:`~transformers.Trainer` class which handles much of the complexity of training for you.
|
||||
|
||||
@@ -279,6 +279,7 @@ Finally, you can view the results, including any calculated metrics, by launchin
|
||||
``logging_dir`` directory.
|
||||
|
||||
|
||||
|
||||
.. _additional-resources:
|
||||
|
||||
Additional resources
|
||||
|
||||
@@ -54,7 +54,7 @@ Coming soon!
|
||||
| Task | Example datasets | Trainer support | TFTrainer support | 🤗 Datasets | Colab
|
||||
|---|---|:---:|:---:|:---:|:---:|
|
||||
| [**`language-modeling`**](https://github.com/huggingface/transformers/tree/master/examples/language-modeling) | Raw text | ✅ | - | ✅ | [](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)
|
||||
| [**`multiple-choice`**](https://github.com/huggingface/transformers/tree/master/examples/multiple-choice) | SWAG, RACE, ARC | ✅ | ✅ | - | [](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb)
|
||||
| [**`multiple-choice`**](https://github.com/huggingface/transformers/tree/master/examples/multiple-choice) | SWAG, RACE, ARC | ✅ | ✅ | ✅ | [](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb)
|
||||
| [**`question-answering`**](https://github.com/huggingface/transformers/tree/master/examples/question-answering) | SQuAD | ✅ | ✅ | ✅ | [](https://github.com/huggingface/notebooks/blob/master/examples/question_answering.ipynb)
|
||||
| [**`summarization`**](https://github.com/huggingface/transformers/tree/master/examples/seq2seq) | CNN/Daily Mail | ✅ | - | - | -
|
||||
| [**`text-classification`**](https://github.com/huggingface/transformers/tree/master/examples/text-classification) | GLUE, XNLI | ✅ | ✅ | ✅ | [](https://github.com/huggingface/notebooks/blob/master/examples/text_classification.ipynb)
|
||||
@@ -69,6 +69,43 @@ Coming soon!
|
||||
**Coming soon!**
|
||||
-->
|
||||
|
||||
## Distributed training and mixed precision
|
||||
|
||||
All the PyTorch scripts mentioned above work out of the box with distributed training and mixed precision, thanks to
|
||||
the [Trainer API](https://huggingface.co/transformers/main_classes/trainer.html). To launch one of them on _n_ GPUS,
|
||||
use the following command:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch \
|
||||
--nproc_per_node number_of_gpu_you_have path_to_script.py \
|
||||
--all_arguments_of_the_script
|
||||
```
|
||||
|
||||
As an example, here is how you would fine-tune the BERT large model (with whole word masking) on the text
|
||||
classification MNLI task using the `run_glue` script, with 8 GPUs:
|
||||
|
||||
```bash
|
||||
python -m torch.distributed.launch \
|
||||
--nproc_per_node 8 text-classification/run_glue.py \
|
||||
--model_name_or_path bert-large-uncased-whole-word-masking \
|
||||
--task_name mnli \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--max_seq_length 128 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--output_dir /tmp/mnli_output/
|
||||
```
|
||||
|
||||
If you have a GPU with mixed precision capabilities (architecture Pascal or more recent), you can use mixed precision
|
||||
training with PyTorch 1.6.0 or latest, or by installing the [Apex](https://github.com/NVIDIA/apex) library for previous
|
||||
versions. Just add the flag `--fp16` to your command launching one of the scripts mentioned above!
|
||||
|
||||
Using mixed precision training usually results in 2x-speedup for training with the same final results (as shown in
|
||||
[this table](https://github.com/huggingface/transformers/tree/master/examples/text-classification#mixed-precision-training)
|
||||
for text classification).
|
||||
|
||||
## Running on TPUs
|
||||
|
||||
When using Tensorflow, TPUs are supported out of the box as a `tf.distribute.Strategy`.
|
||||
@@ -76,27 +113,34 @@ When using Tensorflow, TPUs are supported out of the box as a `tf.distribute.Str
|
||||
When using PyTorch, we support TPUs thanks to `pytorch/xla`. For more context and information on how to setup your TPU environment refer to Google's documentation and to the
|
||||
very detailed [pytorch/xla README](https://github.com/pytorch/xla/blob/master/README.md).
|
||||
|
||||
In this repo, we provide a very simple launcher script named [xla_spawn.py](https://github.com/huggingface/transformers/tree/master/examples/xla_spawn.py) that lets you run our example scripts on multiple TPU cores without any boilerplate.
|
||||
Just pass a `--num_cores` flag to this script, then your regular training script with its arguments (this is similar to the `torch.distributed.launch` helper for torch.distributed).
|
||||
Note that this approach does not work for examples that use `pytorch-lightning`.
|
||||
|
||||
For example for `run_glue`:
|
||||
In this repo, we provide a very simple launcher script named
|
||||
[xla_spawn.py](https://github.com/huggingface/transformers/tree/master/examples/xla_spawn.py) that lets you run our
|
||||
example scripts on multiple TPU cores without any boilerplate. Just pass a `--num_cores` flag to this script, then your
|
||||
regular training script with its arguments (this is similar to the `torch.distributed.launch` helper for
|
||||
`torch.distributed`):
|
||||
|
||||
```bash
|
||||
python examples/xla_spawn.py --num_cores 8 \
|
||||
examples/text-classification/run_glue.py \
|
||||
--model_name_or_path bert-base-cased \
|
||||
--task_name mnli \
|
||||
--data_dir ./data/glue_data/MNLI \
|
||||
--output_dir ./models/tpu \
|
||||
--overwrite_output_dir \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--num_train_epochs 1 \
|
||||
--save_steps 20000
|
||||
python xla_spawn.py --num_cores num_tpu_you_have \
|
||||
path_to_script.py \
|
||||
--all_arguments_of_the_script
|
||||
```
|
||||
|
||||
Feedback and more use cases and benchmarks involving TPUs are welcome, please share with the community.
|
||||
As an example, here is how you would fine-tune the BERT large model (with whole word masking) on the text
|
||||
classification MNLI task using the `run_glue` script, with 8 TPUs:
|
||||
|
||||
```bash
|
||||
python xla_spawn.py --num_cores 8 \
|
||||
text-classification/run_glue.py \
|
||||
--model_name_or_path bert-large-uncased-whole-word-masking \
|
||||
--task_name mnli \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--max_seq_length 128 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--output_dir /tmp/mnli_output/
|
||||
```
|
||||
|
||||
## Logging & Experiment tracking
|
||||
|
||||
|
||||
@@ -25,8 +25,7 @@ objectives in our [model summary](https://huggingface.co/transformers/model_summ
|
||||
These scripts leverage the 🤗 Datasets library and the Trainer API. You can easily customize them to your needs if you
|
||||
need extra processing on your datasets.
|
||||
|
||||
**Note:** The old script `run_language_modeling.py` is still available
|
||||
[here](https://github.com/huggingface/transformers/blob/master/examples/contrib/legacy/run_language_modeling.py).
|
||||
**Note:** The old script `run_language_modeling.py` is still available [here](https://github.com/huggingface/transformers/blob/master/examples/legacy/run_language_modeling.py).
|
||||
|
||||
The following examples, will run on a datasets hosted on our [hub](https://huggingface.co/datasets) or with your own
|
||||
text files for training and validation. We give examples of both below.
|
||||
|
||||
@@ -83,6 +83,17 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -224,22 +235,29 @@ def main():
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
config_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
if model_args.config_name:
|
||||
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
else:
|
||||
config = CONFIG_MAPPING[model_args.model_type]()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"use_fast": model_args.use_fast_tokenizer,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
if model_args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
@@ -252,6 +270,8 @@ def main():
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
@@ -341,9 +361,20 @@ def main():
|
||||
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
|
||||
else None
|
||||
)
|
||||
trainer.train(model_path=model_path)
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_train_file, "w") as writer:
|
||||
logger.info("***** Train results *****")
|
||||
for key, value in sorted(train_result.metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
@@ -358,7 +389,7 @@ def main():
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in results.items():
|
||||
for key, value in sorted(results.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
|
||||
@@ -81,6 +81,17 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -234,22 +245,29 @@ def main():
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
config_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
if model_args.config_name:
|
||||
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
else:
|
||||
config = CONFIG_MAPPING[model_args.model_type]()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"use_fast": model_args.use_fast_tokenizer,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
if model_args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
@@ -262,6 +280,8 @@ def main():
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
@@ -376,9 +396,20 @@ def main():
|
||||
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
|
||||
else None
|
||||
)
|
||||
trainer.train(model_path=model_path)
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_train_file, "w") as writer:
|
||||
logger.info("***** Train results *****")
|
||||
for key, value in sorted(train_result.metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
@@ -393,7 +424,7 @@ def main():
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in results.items():
|
||||
for key, value in sorted(results.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
|
||||
@@ -83,6 +83,17 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -247,22 +258,29 @@ def main():
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
config_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
if model_args.config_name:
|
||||
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
else:
|
||||
config = CONFIG_MAPPING[model_args.model_type]()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"use_fast": model_args.use_fast_tokenizer,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
if model_args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
@@ -275,6 +293,8 @@ def main():
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
@@ -334,9 +354,20 @@ def main():
|
||||
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
|
||||
else None
|
||||
)
|
||||
trainer.train(model_path=model_path)
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_train_file, "w") as writer:
|
||||
logger.info("***** Train results *****")
|
||||
for key, value in sorted(train_result.metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
@@ -351,7 +382,7 @@ def main():
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in results.items():
|
||||
for key, value in sorted(results.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
|
||||
@@ -71,6 +71,17 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -231,22 +242,29 @@ def main():
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
config_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
if model_args.config_name:
|
||||
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
else:
|
||||
config = XLNetConfig()
|
||||
logger.warning("You are instantiating a new config instance from scratch.")
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"use_fast": model_args.use_fast_tokenizer,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
}
|
||||
if model_args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
@@ -259,6 +277,8 @@ def main():
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
@@ -363,9 +383,20 @@ def main():
|
||||
if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
|
||||
else None
|
||||
)
|
||||
trainer.train(model_path=model_path)
|
||||
train_result = trainer.train(model_path=model_path)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_train_file, "w") as writer:
|
||||
logger.info("***** Train results *****")
|
||||
for key, value in sorted(train_result.metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
@@ -380,7 +411,7 @@ def main():
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in results.items():
|
||||
for key, value in sorted(results.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
|
||||
579
examples/legacy/multiple_choice/utils_multiple_choice.py
Normal file
579
examples/legacy/multiple_choice/utils_multiple_choice.py
Normal file
@@ -0,0 +1,579 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
|
||||
|
||||
|
||||
import csv
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import tqdm
|
||||
|
||||
from filelock import FileLock
|
||||
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputExample:
|
||||
"""
|
||||
A single training/test example for multiple choice
|
||||
|
||||
Args:
|
||||
example_id: Unique id for the example.
|
||||
question: string. The untokenized text of the second sequence (question).
|
||||
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
|
||||
endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
|
||||
label: (Optional) string. The label of the example. This should be
|
||||
specified for train and dev examples, but not for test examples.
|
||||
"""
|
||||
|
||||
example_id: str
|
||||
question: str
|
||||
contexts: List[str]
|
||||
endings: List[str]
|
||||
label: Optional[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputFeatures:
|
||||
"""
|
||||
A single set of features of data.
|
||||
Property names are the same names as the corresponding inputs to a model.
|
||||
"""
|
||||
|
||||
example_id: str
|
||||
input_ids: List[List[int]]
|
||||
attention_mask: Optional[List[List[int]]]
|
||||
token_type_ids: Optional[List[List[int]]]
|
||||
label: Optional[int]
|
||||
|
||||
|
||||
class Split(Enum):
|
||||
train = "train"
|
||||
dev = "dev"
|
||||
test = "test"
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
class MultipleChoiceDataset(Dataset):
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
task: str,
|
||||
max_seq_length: Optional[int] = None,
|
||||
overwrite_cache=False,
|
||||
mode: Split = Split.train,
|
||||
):
|
||||
processor = processors[task]()
|
||||
|
||||
cached_features_file = os.path.join(
|
||||
data_dir,
|
||||
"cached_{}_{}_{}_{}".format(
|
||||
mode.value,
|
||||
tokenizer.__class__.__name__,
|
||||
str(max_seq_length),
|
||||
task,
|
||||
),
|
||||
)
|
||||
|
||||
# Make sure only the first process in distributed training processes the dataset,
|
||||
# and the others will use the cache.
|
||||
lock_path = cached_features_file + ".lock"
|
||||
with FileLock(lock_path):
|
||||
|
||||
if os.path.exists(cached_features_file) and not overwrite_cache:
|
||||
logger.info(f"Loading features from cached file {cached_features_file}")
|
||||
self.features = torch.load(cached_features_file)
|
||||
else:
|
||||
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||
label_list = processor.get_labels()
|
||||
if mode == Split.dev:
|
||||
examples = processor.get_dev_examples(data_dir)
|
||||
elif mode == Split.test:
|
||||
examples = processor.get_test_examples(data_dir)
|
||||
else:
|
||||
examples = processor.get_train_examples(data_dir)
|
||||
logger.info("Training examples: %s", len(examples))
|
||||
self.features = convert_examples_to_features(
|
||||
examples,
|
||||
label_list,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
)
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(self.features, cached_features_file)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
class TFMultipleChoiceDataset:
|
||||
"""
|
||||
This will be superseded by a framework-agnostic approach
|
||||
soon.
|
||||
"""
|
||||
|
||||
features: List[InputFeatures]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_dir: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
task: str,
|
||||
max_seq_length: Optional[int] = 128,
|
||||
overwrite_cache=False,
|
||||
mode: Split = Split.train,
|
||||
):
|
||||
processor = processors[task]()
|
||||
|
||||
logger.info(f"Creating features from dataset file at {data_dir}")
|
||||
label_list = processor.get_labels()
|
||||
if mode == Split.dev:
|
||||
examples = processor.get_dev_examples(data_dir)
|
||||
elif mode == Split.test:
|
||||
examples = processor.get_test_examples(data_dir)
|
||||
else:
|
||||
examples = processor.get_train_examples(data_dir)
|
||||
logger.info("Training examples: %s", len(examples))
|
||||
|
||||
self.features = convert_examples_to_features(
|
||||
examples,
|
||||
label_list,
|
||||
max_seq_length,
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
def gen():
|
||||
for (ex_index, ex) in tqdm.tqdm(enumerate(self.features), desc="convert examples to features"):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||
|
||||
yield (
|
||||
{
|
||||
"example_id": 0,
|
||||
"input_ids": ex.input_ids,
|
||||
"attention_mask": ex.attention_mask,
|
||||
"token_type_ids": ex.token_type_ids,
|
||||
},
|
||||
ex.label,
|
||||
)
|
||||
|
||||
self.dataset = tf.data.Dataset.from_generator(
|
||||
gen,
|
||||
(
|
||||
{
|
||||
"example_id": tf.int32,
|
||||
"input_ids": tf.int32,
|
||||
"attention_mask": tf.int32,
|
||||
"token_type_ids": tf.int32,
|
||||
},
|
||||
tf.int64,
|
||||
),
|
||||
(
|
||||
{
|
||||
"example_id": tf.TensorShape([]),
|
||||
"input_ids": tf.TensorShape([None, None]),
|
||||
"attention_mask": tf.TensorShape([None, None]),
|
||||
"token_type_ids": tf.TensorShape([None, None]),
|
||||
},
|
||||
tf.TensorShape([]),
|
||||
),
|
||||
)
|
||||
|
||||
def get_dataset(self):
|
||||
self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
|
||||
|
||||
return self.dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
def __getitem__(self, i) -> InputFeatures:
|
||||
return self.features[i]
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
"""Base class for data converters for multiple choice data sets."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the train set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the dev set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""Gets a collection of `InputExample`s for the test set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_labels(self):
|
||||
"""Gets the list of labels for this data set."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class RaceProcessor(DataProcessor):
|
||||
"""Processor for the RACE data set."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
high = os.path.join(data_dir, "train/high")
|
||||
middle = os.path.join(data_dir, "train/middle")
|
||||
high = self._read_txt(high)
|
||||
middle = self._read_txt(middle)
|
||||
return self._create_examples(high + middle, "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
high = os.path.join(data_dir, "dev/high")
|
||||
middle = os.path.join(data_dir, "dev/middle")
|
||||
high = self._read_txt(high)
|
||||
middle = self._read_txt(middle)
|
||||
return self._create_examples(high + middle, "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} test".format(data_dir))
|
||||
high = os.path.join(data_dir, "test/high")
|
||||
middle = os.path.join(data_dir, "test/middle")
|
||||
high = self._read_txt(high)
|
||||
middle = self._read_txt(middle)
|
||||
return self._create_examples(high + middle, "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_txt(self, input_dir):
|
||||
lines = []
|
||||
files = glob.glob(input_dir + "/*txt")
|
||||
for file in tqdm.tqdm(files, desc="read files"):
|
||||
with open(file, "r", encoding="utf-8") as fin:
|
||||
data_raw = json.load(fin)
|
||||
data_raw["race_id"] = file
|
||||
lines.append(data_raw)
|
||||
return lines
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (_, data_raw) in enumerate(lines):
|
||||
race_id = "%s-%s" % (set_type, data_raw["race_id"])
|
||||
article = data_raw["article"]
|
||||
for i in range(len(data_raw["answers"])):
|
||||
truth = str(ord(data_raw["answers"][i]) - ord("A"))
|
||||
question = data_raw["questions"][i]
|
||||
options = data_raw["options"][i]
|
||||
|
||||
examples.append(
|
||||
InputExample(
|
||||
example_id=race_id,
|
||||
question=question,
|
||||
contexts=[article, article, article, article], # this is not efficient but convenient
|
||||
endings=[options[0], options[1], options[2], options[3]],
|
||||
label=truth,
|
||||
)
|
||||
)
|
||||
return examples
|
||||
|
||||
|
||||
class SynonymProcessor(DataProcessor):
|
||||
"""Processor for the Synonym data set."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "mctrain.csv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "mchp.csv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "mctest.csv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1", "2", "3", "4"]
|
||||
|
||||
def _read_csv(self, input_file):
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
return list(csv.reader(f))
|
||||
|
||||
def _create_examples(self, lines: List[List[str]], type: str):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
|
||||
examples = [
|
||||
InputExample(
|
||||
example_id=line[0],
|
||||
question="", # in the swag dataset, the
|
||||
# common beginning of each
|
||||
# choice is stored in "sent2".
|
||||
contexts=[line[1], line[1], line[1], line[1], line[1]],
|
||||
endings=[line[2], line[3], line[4], line[5], line[6]],
|
||||
label=line[7],
|
||||
)
|
||||
for line in lines # we skip the line with the column names
|
||||
]
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
class SwagProcessor(DataProcessor):
|
||||
"""Processor for the SWAG data set."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "train.csv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "val.csv")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
raise ValueError(
|
||||
"For swag testing, the input file does not contain a label column. It can not be tested in current code"
|
||||
"setting!"
|
||||
)
|
||||
return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_csv(self, input_file):
|
||||
with open(input_file, "r", encoding="utf-8") as f:
|
||||
return list(csv.reader(f))
|
||||
|
||||
def _create_examples(self, lines: List[List[str]], type: str):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
if type == "train" and lines[0][-1] != "label":
|
||||
raise ValueError("For training, the input file must contain a label column.")
|
||||
|
||||
examples = [
|
||||
InputExample(
|
||||
example_id=line[2],
|
||||
question=line[5], # in the swag dataset, the
|
||||
# common beginning of each
|
||||
# choice is stored in "sent2".
|
||||
contexts=[line[4], line[4], line[4], line[4]],
|
||||
endings=[line[7], line[8], line[9], line[10]],
|
||||
label=line[11],
|
||||
)
|
||||
for line in lines[1:] # we skip the line with the column names
|
||||
]
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
class ArcProcessor(DataProcessor):
|
||||
"""Processor for the ARC data set (request from allennlp)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} train".format(data_dir))
|
||||
return self._create_examples(self._read_json(os.path.join(data_dir, "train.jsonl")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
logger.info("LOOKING AT {} dev".format(data_dir))
|
||||
return self._create_examples(self._read_json(os.path.join(data_dir, "dev.jsonl")), "dev")
|
||||
|
||||
def get_test_examples(self, data_dir):
|
||||
logger.info("LOOKING AT {} test".format(data_dir))
|
||||
return self._create_examples(self._read_json(os.path.join(data_dir, "test.jsonl")), "test")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1", "2", "3"]
|
||||
|
||||
def _read_json(self, input_file):
|
||||
with open(input_file, "r", encoding="utf-8") as fin:
|
||||
lines = fin.readlines()
|
||||
return lines
|
||||
|
||||
def _create_examples(self, lines, type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
|
||||
# There are two types of labels. They should be normalized
|
||||
def normalize(truth):
|
||||
if truth in "ABCD":
|
||||
return ord(truth) - ord("A")
|
||||
elif truth in "1234":
|
||||
return int(truth) - 1
|
||||
else:
|
||||
logger.info("truth ERROR! %s", str(truth))
|
||||
return None
|
||||
|
||||
examples = []
|
||||
three_choice = 0
|
||||
four_choice = 0
|
||||
five_choice = 0
|
||||
other_choices = 0
|
||||
# we deleted example which has more than or less than four choices
|
||||
for line in tqdm.tqdm(lines, desc="read arc data"):
|
||||
data_raw = json.loads(line.strip("\n"))
|
||||
if len(data_raw["question"]["choices"]) == 3:
|
||||
three_choice += 1
|
||||
continue
|
||||
elif len(data_raw["question"]["choices"]) == 5:
|
||||
five_choice += 1
|
||||
continue
|
||||
elif len(data_raw["question"]["choices"]) != 4:
|
||||
other_choices += 1
|
||||
continue
|
||||
four_choice += 1
|
||||
truth = str(normalize(data_raw["answerKey"]))
|
||||
assert truth != "None"
|
||||
question_choices = data_raw["question"]
|
||||
question = question_choices["stem"]
|
||||
id = data_raw["id"]
|
||||
options = question_choices["choices"]
|
||||
if len(options) == 4:
|
||||
examples.append(
|
||||
InputExample(
|
||||
example_id=id,
|
||||
question=question,
|
||||
contexts=[
|
||||
options[0]["para"].replace("_", ""),
|
||||
options[1]["para"].replace("_", ""),
|
||||
options[2]["para"].replace("_", ""),
|
||||
options[3]["para"].replace("_", ""),
|
||||
],
|
||||
endings=[options[0]["text"], options[1]["text"], options[2]["text"], options[3]["text"]],
|
||||
label=truth,
|
||||
)
|
||||
)
|
||||
|
||||
if type == "train":
|
||||
assert len(examples) > 1
|
||||
assert examples[0].label is not None
|
||||
logger.info("len examples: %s}", str(len(examples)))
|
||||
logger.info("Three choices: %s", str(three_choice))
|
||||
logger.info("Five choices: %s", str(five_choice))
|
||||
logger.info("Other choices: %s", str(other_choices))
|
||||
logger.info("four choices: %s", str(four_choice))
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(
|
||||
examples: List[InputExample],
|
||||
label_list: List[str],
|
||||
max_length: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
) -> List[InputFeatures]:
|
||||
"""
|
||||
Loads a data file into a list of `InputFeatures`
|
||||
"""
|
||||
|
||||
label_map = {label: i for i, label in enumerate(label_list)}
|
||||
|
||||
features = []
|
||||
for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||
choices_inputs = []
|
||||
for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
|
||||
text_a = context
|
||||
if example.question.find("_") != -1:
|
||||
# this is for cloze question
|
||||
text_b = example.question.replace("_", ending)
|
||||
else:
|
||||
text_b = example.question + " " + ending
|
||||
|
||||
inputs = tokenizer(
|
||||
text_a,
|
||||
text_b,
|
||||
add_special_tokens=True,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_overflowing_tokens=True,
|
||||
)
|
||||
if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
|
||||
logger.info(
|
||||
"Attention! you are cropping tokens (swag task is ok). "
|
||||
"If you are training ARC and RACE and you are poping question + options,"
|
||||
"you need to try to use a bigger max seq length!"
|
||||
)
|
||||
|
||||
choices_inputs.append(inputs)
|
||||
|
||||
label = label_map[example.label]
|
||||
|
||||
input_ids = [x["input_ids"] for x in choices_inputs]
|
||||
attention_mask = (
|
||||
[x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None
|
||||
)
|
||||
token_type_ids = (
|
||||
[x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None
|
||||
)
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
example_id=example.example_id,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
label=label,
|
||||
)
|
||||
)
|
||||
|
||||
for f in features[:2]:
|
||||
logger.info("*** Example ***")
|
||||
logger.info("feature: %s" % f)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
processors = {"race": RaceProcessor, "swag": SwagProcessor, "arc": ArcProcessor, "syn": SynonymProcessor}
|
||||
MULTIPLE_CHOICE_TASKS_NUM_LABELS = {"race", 4, "swag", 4, "arc", 4, "syn", 5}
|
||||
@@ -19,3 +19,4 @@ pytest
|
||||
conllu
|
||||
sentencepiece != 0.1.92
|
||||
protobuf
|
||||
ray
|
||||
|
||||
@@ -16,27 +16,20 @@ limitations under the License.
|
||||
|
||||
## Multiple Choice
|
||||
|
||||
Based on the script [`run_multiple_choice.py`]().
|
||||
Based on the script [`run_swag.py`]().
|
||||
|
||||
#### Fine-tuning on SWAG
|
||||
Download [swag](https://github.com/rowanz/swagaf/tree/master/data) data
|
||||
|
||||
```bash
|
||||
#training on 4 tesla V100(16GB) GPUS
|
||||
export SWAG_DIR=/path/to/swag_data_dir
|
||||
python ./examples/multiple-choice/run_multiple_choice.py \
|
||||
--task_name swag \
|
||||
python examples/multiple-choice/run_swag.py \
|
||||
--model_name_or_path roberta-base \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--data_dir $SWAG_DIR \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3 \
|
||||
--max_seq_length 80 \
|
||||
--output_dir models_bert/swag_base \
|
||||
--output_dir /tmp/swag_base \
|
||||
--per_gpu_eval_batch_size=16 \
|
||||
--per_device_train_batch_size=16 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--overwrite_output
|
||||
```
|
||||
Training with the defined hyper-parameters yields the following results:
|
||||
|
||||
377
examples/multiple-choice/run_swag.py
Normal file
377
examples/multiple-choice/run_swag.py
Normal file
@@ -0,0 +1,377 @@
|
||||
# coding=utf-8
|
||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for multiple choice.
|
||||
"""
|
||||
# You can also adapt this script on your own multiple choice task. Pointers for this are left as comments.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
default_data_collator,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. If passed, sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to the maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForMultipleChoice:
|
||||
"""
|
||||
Data collator that will dynamically pad the inputs for multiple choice received.
|
||||
|
||||
Args:
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||
The tokenizer used for encoding the data.
|
||||
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
||||
among:
|
||||
|
||||
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
||||
sequence if provided).
|
||||
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
||||
maximum acceptable input length for the model if that argument is not provided.
|
||||
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
||||
different lengths).
|
||||
max_length (:obj:`int`, `optional`):
|
||||
Maximum length of the returned list and optionally padding length (see above).
|
||||
pad_to_multiple_of (:obj:`int`, `optional`):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
||||
7.5 (Volta).
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
max_length: Optional[int] = None
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
|
||||
def __call__(self, features):
|
||||
label_name = "label" if "label" in features[0].keys() else "labels"
|
||||
labels = [feature.pop(label_name) for feature in features]
|
||||
batch_size = len(features)
|
||||
num_choices = len(features[0]["input_ids"])
|
||||
flattened_features = [
|
||||
[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
|
||||
]
|
||||
flattened_features = sum(flattened_features, [])
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
flattened_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Un-flatten
|
||||
batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
|
||||
# Add back labels
|
||||
batch["labels"] = torch.tensor(labels, dtype=torch.int64)
|
||||
return batch
|
||||
|
||||
|
||||
def main():
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
|
||||
)
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
if is_main_process(training_args.local_rank):
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
|
||||
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
||||
# 'text' is found. You can easily tweak this behavior (see below).
|
||||
|
||||
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if data_args.train_file is not None or data_args.validation_file is not None:
|
||||
data_files = {}
|
||||
if data_args.train_file is not None:
|
||||
data_files["train"] = data_args.train_file
|
||||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.train_file.split(".")[-1]
|
||||
datasets = load_dataset(extension, data_files=data_files)
|
||||
else:
|
||||
# Downloading and loading the swag dataset from the hub.
|
||||
datasets = load_dataset("swag", "regular")
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
model = AutoModelForMultipleChoice.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# When using your own dataset or a different dataset from swag, you will probably need to change this.
|
||||
ending_names = [f"ending{i}" for i in range(4)]
|
||||
context_name = "sent1"
|
||||
question_header_name = "sent2"
|
||||
|
||||
# Preprocessing the datasets.
|
||||
def preprocess_function(examples):
|
||||
first_sentences = [[context] * 4 for context in examples[context_name]]
|
||||
question_headers = examples[question_header_name]
|
||||
second_sentences = [
|
||||
[f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers)
|
||||
]
|
||||
|
||||
# Flatten out
|
||||
first_sentences = sum(first_sentences, [])
|
||||
second_sentences = sum(second_sentences, [])
|
||||
|
||||
# Tokenize
|
||||
tokenized_examples = tokenizer(
|
||||
first_sentences,
|
||||
second_sentences,
|
||||
truncation=True,
|
||||
max_length=data_args.max_seq_length,
|
||||
padding="max_length" if data_args.pad_to_max_length else False,
|
||||
)
|
||||
# Un-flatten
|
||||
return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}
|
||||
|
||||
tokenized_datasets = datasets.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Data collator
|
||||
data_collator = (
|
||||
default_data_collator if data_args.pad_to_max_length else DataCollatorForMultipleChoice(tokenizer=tokenizer)
|
||||
)
|
||||
|
||||
# Metric
|
||||
def compute_metrics(eval_predictions):
|
||||
predictions, label_ids = eval_predictions
|
||||
preds = np.argmax(predictions, axis=1)
|
||||
return {"accuracy": (preds == label_ids).astype(np.float32).mean().item()}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
|
||||
eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(
|
||||
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
|
||||
)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_train_file, "w") as writer:
|
||||
logger.info("***** Train results *****")
|
||||
for key, value in sorted(train_result.metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
results = trainer.evaluate()
|
||||
|
||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results_swag.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in sorted(results.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
# For xla_spawn (TPUs)
|
||||
main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -23,8 +23,7 @@ uses special features of those tokenizers. You can check if your favorite model
|
||||
[this table](https://huggingface.co/transformers/index.html#bigtable), if it doesn't you can still use the old version
|
||||
of the script.
|
||||
|
||||
The old version of this script can be found [here](https://github.com/huggingface/transformers/blob/master/examples/contrib/legacy/question-answering/run_squad.py).
|
||||
|
||||
The old version of this script can be found [here](https://github.com/huggingface/transformers/tree/master/examples/legacy/question-answering).
|
||||
#### Fine-tuning BERT on SQuAD1.0
|
||||
|
||||
This example code fine-tunes BERT on the SQuAD1.0 dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large)
|
||||
|
||||
@@ -65,6 +65,17 @@ class ModelArguments:
|
||||
default=None,
|
||||
metadata={"help": "Path to directory to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -220,17 +231,23 @@ def main():
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_fast=True,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
model = AutoModelForQuestionAnswering.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Tokenizer check: this script requires a fast tokenizer.
|
||||
@@ -438,11 +455,22 @@ def main():
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
trainer.train(
|
||||
train_result = trainer.train(
|
||||
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
|
||||
)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_train_file, "w") as writer:
|
||||
logger.info("***** Train results *****")
|
||||
for key, value in sorted(train_result.metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
@@ -453,7 +481,7 @@ def main():
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in results.items():
|
||||
for key, value in sorted(results.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
|
||||
@@ -64,9 +64,16 @@ class ModelArguments:
|
||||
default=None,
|
||||
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -223,16 +230,22 @@ def main():
|
||||
config = XLNetConfig.from_pretrained(
|
||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
tokenizer = XLNetTokenizerFast.from_pretrained(
|
||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
model = XLNetForQuestionAnswering.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
@@ -481,11 +494,22 @@ def main():
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
trainer.train(
|
||||
train_result = trainer.train(
|
||||
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
|
||||
)
|
||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||
|
||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_train_file, "w") as writer:
|
||||
logger.info("***** Train results *****")
|
||||
for key, value in sorted(train_result.metrics.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if training_args.do_eval:
|
||||
@@ -496,7 +520,7 @@ def main():
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key, value in results.items():
|
||||
for key, value in sorted(results.items()):
|
||||
logger.info(f" {key} = {value}")
|
||||
writer.write(f"{key} = {value}\n")
|
||||
|
||||
|
||||
@@ -206,7 +206,7 @@ def postprocess_qa_predictions(
|
||||
|
||||
# Make `predictions` JSON-serializable by casting np.float back to float.
|
||||
all_nbest_json[example["id"]] = [
|
||||
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
for pred in predictions
|
||||
]
|
||||
|
||||
@@ -394,7 +394,7 @@ def postprocess_qa_predictions_with_beam_search(
|
||||
|
||||
# Make `predictions` JSON-serializable by casting np.float back to float.
|
||||
all_nbest_json[example["id"]] = [
|
||||
{k: (float(v) if isinstance(v, (np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
|
||||
for pred in predictions
|
||||
]
|
||||
|
||||
|
||||
388
examples/research_projects/bertology/run_prune_gpt.py
Normal file
388
examples/research_projects/bertology/run_prune_gpt.py
Normal file
@@ -0,0 +1,388 @@
|
||||
#!/usr/bin/env python3
|
||||
""" This script is adapted from the Bertology pruning code (https://github.com/huggingface/transformers/blob/783d7d2629e97c5f0c5f9ef01b8c66410275c204/examples/research_projects/bertology/run_bertology.py)
|
||||
to prune GPT-like models. The author is @altsoph.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import GPT2LMHeadModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def save_model(model, dirpath):
|
||||
# save results
|
||||
if os.path.exists(dirpath):
|
||||
if os.path.exists(os.path.join(dirpath, "config.json")) and os.path.isfile(
|
||||
os.path.join(dirpath, "config.json")
|
||||
):
|
||||
os.remove(os.path.join(dirpath, "config.json"))
|
||||
if os.path.exists(os.path.join(dirpath, "pytorch_model.bin")) and os.path.isfile(
|
||||
os.path.join(dirpath, "pytorch_model.bin")
|
||||
):
|
||||
os.remove(os.path.join(dirpath, "pytorch_model.bin"))
|
||||
else:
|
||||
os.makedirs(dirpath)
|
||||
model.save_pretrained(dirpath)
|
||||
|
||||
|
||||
def entropy(p, unlogit=False):
|
||||
""" Compute the entropy of a probability distribution """
|
||||
exponent = 2
|
||||
if unlogit:
|
||||
p = torch.pow(p, exponent)
|
||||
plogp = p * torch.log(p)
|
||||
plogp[p == 0] = 0
|
||||
return -plogp.sum(dim=-1)
|
||||
|
||||
|
||||
def print_2d_tensor(tensor):
|
||||
""" Print a 2D tensor """
|
||||
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
|
||||
for row in range(len(tensor)):
|
||||
if tensor.dtype != torch.long:
|
||||
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:.5f}" for x in tensor[row].cpu().data))
|
||||
else:
|
||||
logger.info(f"layer {row + 1}:\t" + "\t".join(f"{x:d}" for x in tensor[row].cpu().data))
|
||||
|
||||
|
||||
def compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
|
||||
):
|
||||
"""This method shows how to compute:
|
||||
- head attention entropy
|
||||
- head importance scores according to http://arxiv.org/abs/1905.10650
|
||||
"""
|
||||
# Prepare our tensors
|
||||
n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
||||
head_importance = torch.zeros(n_layers, n_heads).to(args.device)
|
||||
attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
|
||||
|
||||
if head_mask is None:
|
||||
head_mask = torch.ones(n_layers, n_heads).to(args.device)
|
||||
|
||||
head_mask.requires_grad_(requires_grad=True)
|
||||
# If actually pruned attention multi-head, set head mask to None to avoid shape mismatch
|
||||
if actually_pruned:
|
||||
head_mask = None
|
||||
|
||||
tot_tokens = 0.0
|
||||
total_loss = 0.0
|
||||
for step, inputs in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
|
||||
inputs = tuple(t.to(args.device) for t in inputs)
|
||||
(input_ids,) = inputs
|
||||
|
||||
# Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
|
||||
outputs = model(input_ids, labels=input_ids, head_mask=head_mask)
|
||||
# (loss), lm_logits, presents, (all hidden_states), (attentions)
|
||||
loss, _, all_attentions = (
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
outputs[-1],
|
||||
) # Loss and logits are the first, attention the last
|
||||
loss.backward() # Backpropagate to populate the gradients in the head mask
|
||||
total_loss += loss.detach().cpu().numpy()
|
||||
if compute_entropy:
|
||||
for layer, attn in enumerate(all_attentions):
|
||||
masked_entropy = entropy(attn.detach(), True)
|
||||
attn_entropy[layer] += masked_entropy.sum(-1).sum(0).sum(0).detach()
|
||||
|
||||
if compute_importance:
|
||||
head_importance += head_mask.grad.abs().detach()
|
||||
tot_tokens += torch.ones_like(input_ids).float().detach().sum().data
|
||||
|
||||
# Normalize
|
||||
attn_entropy /= tot_tokens
|
||||
head_importance /= tot_tokens
|
||||
# Layerwise importance normalization
|
||||
if not args.dont_normalize_importance_by_layer:
|
||||
exponent = 2
|
||||
norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
|
||||
head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
|
||||
|
||||
if not args.dont_normalize_global_importance:
|
||||
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
|
||||
|
||||
# Print matrices
|
||||
if compute_entropy:
|
||||
logger.info("Attention entropies")
|
||||
print_2d_tensor(attn_entropy)
|
||||
if compute_importance:
|
||||
logger.info("Head importance scores")
|
||||
print_2d_tensor(head_importance)
|
||||
logger.info("Head ranked by importance scores")
|
||||
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
|
||||
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
|
||||
head_importance.numel(), device=args.device
|
||||
)
|
||||
head_ranks = head_ranks.view_as(head_importance)
|
||||
print_2d_tensor(head_ranks)
|
||||
return attn_entropy, head_importance, total_loss
|
||||
|
||||
|
||||
def mask_heads(args, model, eval_dataloader):
|
||||
"""This method shows how to mask head (set some heads to zero), to test the effect on the network,
|
||||
based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""
|
||||
_, head_importance, loss = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
|
||||
original_score = 1 / loss # instead of downsteam score use the LM loss
|
||||
logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
|
||||
|
||||
new_head_mask = torch.ones_like(head_importance)
|
||||
num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))
|
||||
|
||||
current_score = original_score
|
||||
while current_score >= original_score * args.masking_threshold:
|
||||
head_mask = new_head_mask.clone().detach() # save current head mask
|
||||
# heads from least important to most - keep only not-masked heads
|
||||
head_importance[head_mask == 0.0] = float("Inf")
|
||||
current_heads_to_mask = head_importance.view(-1).sort()[1]
|
||||
|
||||
if len(current_heads_to_mask) <= num_to_mask:
|
||||
print("BREAK BY num_to_mask")
|
||||
break
|
||||
|
||||
# mask heads
|
||||
current_heads_to_mask = current_heads_to_mask[:num_to_mask]
|
||||
logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
|
||||
new_head_mask = new_head_mask.view(-1)
|
||||
new_head_mask[current_heads_to_mask] = 0.0
|
||||
new_head_mask = new_head_mask.view_as(head_mask)
|
||||
new_head_mask = new_head_mask.clone().detach()
|
||||
print_2d_tensor(new_head_mask)
|
||||
|
||||
# Compute metric and head importance again
|
||||
_, head_importance, loss = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
|
||||
)
|
||||
current_score = 1 / loss
|
||||
logger.info(
|
||||
"Masking: current score: %f, remaining heads %d (%.1f percents)",
|
||||
current_score,
|
||||
new_head_mask.sum(),
|
||||
new_head_mask.sum() / new_head_mask.numel() * 100,
|
||||
)
|
||||
|
||||
logger.info("Final head mask")
|
||||
print_2d_tensor(head_mask)
|
||||
np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())
|
||||
|
||||
return head_mask
|
||||
|
||||
|
||||
def prune_heads(args, model, eval_dataloader, head_mask):
|
||||
"""This method shows how to prune head (remove heads weights) based on
|
||||
the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)
|
||||
"""
|
||||
# Try pruning and test time speedup
|
||||
# Pruning is like masking but we actually remove the masked weights
|
||||
before_time = datetime.now()
|
||||
_, _, loss = compute_heads_importance(
|
||||
args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
|
||||
)
|
||||
score_masking = 1 / loss
|
||||
original_time = datetime.now() - before_time
|
||||
|
||||
original_num_params = sum(p.numel() for p in model.parameters())
|
||||
heads_to_prune = dict(
|
||||
(layer, (1 - head_mask[layer].long()).nonzero().squeeze().tolist()) for layer in range(len(head_mask))
|
||||
)
|
||||
|
||||
for k, v in heads_to_prune.items():
|
||||
if isinstance(v, int):
|
||||
heads_to_prune[k] = [
|
||||
v,
|
||||
]
|
||||
|
||||
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
|
||||
model.prune_heads(heads_to_prune)
|
||||
pruned_num_params = sum(p.numel() for p in model.parameters())
|
||||
|
||||
before_time = datetime.now()
|
||||
_, _, loss = compute_heads_importance(
|
||||
args,
|
||||
model,
|
||||
eval_dataloader,
|
||||
compute_entropy=False,
|
||||
compute_importance=False,
|
||||
head_mask=None,
|
||||
actually_pruned=True,
|
||||
)
|
||||
|
||||
score_pruning = 1 / loss
|
||||
new_time = datetime.now() - before_time
|
||||
|
||||
logger.info(
|
||||
"Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
|
||||
original_num_params,
|
||||
pruned_num_params,
|
||||
pruned_num_params / original_num_params * 100,
|
||||
)
|
||||
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
|
||||
logger.info("Pruning: speed ratio (original timing / new timing): %f percents", original_time / new_time * 100)
|
||||
save_model(model, args.output_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained config name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
default="",
|
||||
type=str,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name_or_path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dont_normalize_global_importance",
|
||||
action="store_true",
|
||||
help="Don't normalize all importance scores between 0 and 1",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masking_threshold",
|
||||
default=0.9,
|
||||
type=float,
|
||||
help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
|
||||
)
|
||||
parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")
|
||||
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after WordPiece tokenization. \n"
|
||||
"Sequences longer than this will be truncated, sequences shorter padded.",
|
||||
)
|
||||
parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
|
||||
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
# Setup devices and distributed training
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
||||
else:
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
args.device = torch.device("cuda", args.local_rank)
|
||||
args.n_gpu = 1
|
||||
torch.distributed.init_process_group(backend="nccl") # Initializes the distributed backend
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||
logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
|
||||
|
||||
# Distributed and parallel training
|
||||
model.to(args.device)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
elif args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
|
||||
# Print/save training arguments
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Prepare dataset
|
||||
numpy_data = np.concatenate(
|
||||
[
|
||||
np.loadtxt(args.data_dir, dtype=np.int64),
|
||||
]
|
||||
)
|
||||
train_tensor_dataset = (torch.from_numpy(numpy_data),)
|
||||
train_data = TensorDataset(*train_tensor_dataset)
|
||||
train_sampler = RandomSampler(train_data)
|
||||
eval_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.batch_size)
|
||||
|
||||
# Compute head entropy and importance score
|
||||
compute_heads_importance(args, model, eval_dataloader)
|
||||
|
||||
# Try head masking (set heads to zero until the score goes under a threshole)
|
||||
# and head pruning (remove masked heads and see the effect on the network)
|
||||
if args.try_masking and args.masking_threshold > 0.0 and args.masking_threshold < 1.0:
|
||||
head_mask = mask_heads(args, model, eval_dataloader)
|
||||
prune_heads(args, model, eval_dataloader, head_mask)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -90,7 +90,7 @@ torchvision==0.7.0
|
||||
tornado==6.0.4
|
||||
tqdm==4.48.2
|
||||
traitlets
|
||||
transformers==3.5.1
|
||||
git+https://github.com/huggingface/transformers.git
|
||||
urllib3==1.25.8
|
||||
wcwidth==0.2.5
|
||||
webencodings==0.5.1
|
||||
25
examples/research_projects/performer/README.md
Normal file
25
examples/research_projects/performer/README.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# Performer fine-tuning
|
||||
|
||||
Example authors: @TevenLeScao, @Patrickvonplaten
|
||||
|
||||
Paper authors: Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, David Belanger, Lucy Colwell, Adrian Weller
|
||||
|
||||
## Requirements
|
||||
|
||||
`datasets`, `flax` and `jax`. `wandb` integration is built-in if you want to use it.
|
||||
|
||||
## Examples
|
||||
|
||||
`sanity_script.sh` will launch performer fine-tuning from the bert-base-cased checkpoint on the Simple Wikipedia dataset (a small, easy-language English Wikipedia) from `datasets`.
|
||||
`full_script.sh` will launch performer fine-tuning from the bert-large-cased checkpoint on the English Wikipedia dataset from `datasets`.
|
||||
|
||||
Here are a few key arguments:
|
||||
- Remove the `--performer` argument to use a standard Bert model.
|
||||
|
||||
- Add `--reinitialize` to start from a blank model rather than a Bert checkpoint.
|
||||
|
||||
- You may change the Bert size by passing a different [checkpoint](https://huggingface.co/transformers/pretrained_models.html) to the `--model_name_or_path` argument.
|
||||
|
||||
- Passing your user name to the `--wandb_user_name` argument will trigger weights and biases logging.
|
||||
|
||||
- You can choose a dataset with `--dataset_name` and `--dataset_config`. Our [viewer](https://huggingface.co/datasets/viewer/) will help you find what you need.
|
||||
1
examples/research_projects/performer/full_script.sh
Executable file
1
examples/research_projects/performer/full_script.sh
Executable file
@@ -0,0 +1 @@
|
||||
TOKENIZERS_PARALLELISM=true python run_mlm_performer.py --output_dir experiments --dataset_name wikipedia --dataset_config_name 20200501.en --model_name_or_path bert-large-cased --tokenizer_name bert-large-cased --do_train --overwrite_output_dir --per_device_train_batch_size 4 --learning_rate 5e-4 --warmup_steps 100 --num_train_epochs 3 --performer
|
||||
553
examples/research_projects/performer/modeling_flax_performer.py
Normal file
553
examples/research_projects/performer/modeling_flax_performer.py
Normal file
@@ -0,0 +1,553 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.random import PRNGKey
|
||||
from modeling_flax_performer_utils import make_fast_softmax_attention
|
||||
from transformers.file_utils import add_start_docstrings
|
||||
from transformers.modeling_flax_utils import ACT2FN
|
||||
from transformers.models.bert.configuration_bert import BertConfig
|
||||
from transformers.models.bert.modeling_flax_bert import FlaxBertOnlyMLMHead, FlaxBertPreTrainedModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "BertConfig"
|
||||
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
||||
|
||||
BERT_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
|
||||
pruning heads etc.)
|
||||
|
||||
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
||||
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
||||
weights.
|
||||
"""
|
||||
|
||||
BERT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 0 corresponds to a `sentence A` token,
|
||||
- 1 corresponds to a `sentence B` token.
|
||||
|
||||
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
||||
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
||||
config.max_position_embeddings - 1]``.
|
||||
|
||||
`What are position IDs? <../glossary.html#position-ids>`_
|
||||
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
||||
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
||||
vectors than the model's internal embedding lookup matrix.
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`):
|
||||
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
||||
more detail.
|
||||
return_dict (:obj:`bool`, `optional`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
class FlaxPerformerLayerNorm(nn.Module):
|
||||
"""
|
||||
Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
|
||||
"""
|
||||
|
||||
epsilon: float = 1e-6
|
||||
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
||||
bias: bool = True # If True, bias (beta) is added.
|
||||
scale: bool = True # If True, multiply by scale (gamma). When the next layer is linear
|
||||
# (also e.g. nn.relu), this can be disabled since the scaling will be
|
||||
# done by the next layer.
|
||||
bias_init: jnp.ndarray = nn.initializers.zeros
|
||||
scale_init: jnp.ndarray = nn.initializers.ones
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
"""
|
||||
Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
|
||||
a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that
|
||||
maintains the mean activation within each example close to 0 and the activation standard deviation close to 1
|
||||
|
||||
Args:
|
||||
x: the inputs
|
||||
|
||||
Returns:
|
||||
Normalized inputs (the same shape as inputs).
|
||||
"""
|
||||
features = x.shape[-1]
|
||||
mean = jnp.mean(x, axis=-1, keepdims=True)
|
||||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
|
||||
var = mean2 - jax.lax.square(mean)
|
||||
mul = jax.lax.rsqrt(var + self.epsilon)
|
||||
if self.scale:
|
||||
mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
|
||||
y = (x - mean) * mul
|
||||
if self.bias:
|
||||
y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
|
||||
return y
|
||||
|
||||
|
||||
class FlaxPerformerEmbedding(nn.Module):
|
||||
"""
|
||||
Specify a new class for doing the embedding stuff as Flax's one use 'embedding' for the parameter name and PyTorch
|
||||
use 'weight'
|
||||
"""
|
||||
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs):
|
||||
embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
|
||||
return jnp.take(embedding, inputs, axis=0)
|
||||
|
||||
|
||||
class FlaxPerformerEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||
# Embed
|
||||
w_emb = FlaxPerformerEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
|
||||
jnp.atleast_2d(input_ids.astype("i4"))
|
||||
)
|
||||
p_emb = FlaxPerformerEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
|
||||
jnp.atleast_2d(position_ids.astype("i4"))
|
||||
)
|
||||
t_emb = FlaxPerformerEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
|
||||
jnp.atleast_2d(token_type_ids.astype("i4"))
|
||||
)
|
||||
|
||||
# Sum all embeddings
|
||||
summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb
|
||||
|
||||
# Layer Norm
|
||||
layer_norm = FlaxPerformerLayerNorm(name="layer_norm")(summed_emb)
|
||||
|
||||
return layer_norm
|
||||
|
||||
|
||||
class FlaxPerformerAttention(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
single_head_dim = self.head_size // self.num_heads
|
||||
fast_softmax_attention = make_fast_softmax_attention(qkv_dim=single_head_dim)
|
||||
self_att = nn.attention.SelfAttention(
|
||||
num_heads=self.num_heads, qkv_features=self.head_size, name="self", attention_fn=fast_softmax_attention
|
||||
)(hidden_state, attention_mask)
|
||||
|
||||
layer_norm = FlaxPerformerLayerNorm(name="layer_norm")(self_att + hidden_state)
|
||||
return layer_norm
|
||||
|
||||
|
||||
class FlaxPerformerIntermediate(nn.Module):
|
||||
output_size: int
|
||||
hidden_act: str = "gelu"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
# TODO: Add ACT2FN reference to change activation function
|
||||
dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
|
||||
return ACT2FN[self.hidden_act](dense)
|
||||
|
||||
|
||||
class FlaxPerformerOutput(nn.Module):
|
||||
@nn.compact
|
||||
def __call__(self, intermediate_output, attention_output):
|
||||
hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
|
||||
hidden_state = FlaxPerformerLayerNorm(name="layer_norm")(hidden_state + attention_output)
|
||||
return hidden_state
|
||||
|
||||
|
||||
class FlaxPerformerLayer(nn.Module):
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
attention = FlaxPerformerAttention(self.num_heads, self.head_size, name="attention")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
intermediate = FlaxPerformerIntermediate(
|
||||
self.intermediate_size, name="intermediate", hidden_act=self.hidden_act
|
||||
)(attention)
|
||||
output = FlaxPerformerOutput(name="output")(intermediate, attention)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FlaxPerformerLayerCollection(nn.Module):
|
||||
"""
|
||||
Stores N BertLayer(s)
|
||||
"""
|
||||
|
||||
num_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs, attention_mask):
|
||||
assert self.num_layers > 0, f"num_layers should be >= 1, got ({self.num_layers})"
|
||||
|
||||
# Initialize input / output
|
||||
input_i = inputs
|
||||
|
||||
# Forward over all encoders
|
||||
for i in range(self.num_layers):
|
||||
layer = FlaxPerformerLayer(
|
||||
self.num_heads, self.head_size, self.intermediate_size, hidden_act=self.hidden_act, name=f"{i}"
|
||||
)
|
||||
input_i = layer(input_i, attention_mask)
|
||||
return input_i
|
||||
|
||||
|
||||
class FlaxPerformerEncoder(nn.Module):
|
||||
num_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
layer = FlaxPerformerLayerCollection(
|
||||
self.num_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
name="layer",
|
||||
hidden_act=self.hidden_act,
|
||||
)(hidden_state, attention_mask)
|
||||
return layer
|
||||
|
||||
|
||||
class FlaxPerformerPooler(nn.Module):
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state):
|
||||
cls_token = hidden_state[:, 0]
|
||||
out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
|
||||
return jax.lax.tanh(out)
|
||||
|
||||
|
||||
class FlaxPerformerModule(nn.Module):
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
num_encoder_layers: int
|
||||
num_heads: int
|
||||
head_size: int
|
||||
intermediate_size: int
|
||||
hidden_act: str = "gelu"
|
||||
add_pooling_layer: bool = True
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
|
||||
# Embedding
|
||||
embeddings = FlaxPerformerEmbeddings(
|
||||
self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
|
||||
)(input_ids, token_type_ids, position_ids, attention_mask)
|
||||
|
||||
# N stacked encoding layers
|
||||
encoder = FlaxPerformerEncoder(
|
||||
self.num_encoder_layers,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
name="encoder",
|
||||
)(embeddings, attention_mask)
|
||||
|
||||
if not self.add_pooling_layer:
|
||||
return encoder
|
||||
|
||||
pooled = FlaxPerformerPooler(name="pooler")(encoder)
|
||||
return encoder, pooled
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
BERT_START_DOCSTRING,
|
||||
)
|
||||
class FlaxPerformerModel(FlaxBertPreTrainedModel):
|
||||
"""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
||||
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
||||
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
"""
|
||||
|
||||
model_class = FlaxPerformerModule
|
||||
config_class = BertConfig
|
||||
base_model_prefix = "bert"
|
||||
|
||||
@staticmethod
|
||||
def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
|
||||
jax_state = dict(pt_state)
|
||||
|
||||
# Need to change some parameters name to match Flax names so that we don't have to fork any layer
|
||||
for key, tensor in pt_state.items():
|
||||
# Key parts
|
||||
key_parts = set(key.split("."))
|
||||
|
||||
# Every dense layer has "kernel" parameters instead of "weight"
|
||||
if "dense.weight" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention needs also to replace "weight" by "kernel"
|
||||
if {"query", "key", "value"} & key_parts:
|
||||
|
||||
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
||||
if "bias" in key:
|
||||
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
||||
elif "weight":
|
||||
del jax_state[key]
|
||||
key = key.replace("weight", "kernel")
|
||||
tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove one nesting
|
||||
if "attention.output.dense" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.dense", "attention.self.out")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# SelfAttention output is not a separate layer, remove nesting on layer norm
|
||||
if "attention.output.LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
|
||||
jax_state[key] = tensor
|
||||
|
||||
# There are some transposed parameters w.r.t their PyTorch counterpart
|
||||
if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Self Attention output projection needs to be transposed
|
||||
if "out.kernel" in key:
|
||||
jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
|
||||
1, 2, 0
|
||||
)
|
||||
|
||||
# Pooler needs to transpose its kernel
|
||||
if "pooler.dense.kernel" in key:
|
||||
jax_state[key] = tensor.T
|
||||
|
||||
# Handle LayerNorm conversion
|
||||
if "LayerNorm" in key:
|
||||
del jax_state[key]
|
||||
|
||||
# Replace LayerNorm by layer_norm
|
||||
new_key = key.replace("LayerNorm", "layer_norm")
|
||||
|
||||
if "weight" in key:
|
||||
new_key = new_key.replace("weight", "gamma")
|
||||
elif "bias" in key:
|
||||
new_key = new_key.replace("bias", "beta")
|
||||
|
||||
jax_state[new_key] = tensor
|
||||
|
||||
return jax_state
|
||||
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
):
|
||||
module = FlaxPerformerModule(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
max_length=config.max_position_embeddings,
|
||||
num_encoder_layers=config.num_hidden_layers,
|
||||
num_heads=config.num_attention_heads,
|
||||
head_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
dropout_rate=config.hidden_dropout_prob,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@property
|
||||
def module(self) -> nn.Module:
|
||||
return self._module
|
||||
|
||||
def __call__(
|
||||
self, input_ids, token_type_ids=None, position_ids=None, dropout_rng: PRNGKey = None, attention_mask=None
|
||||
):
|
||||
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
rng=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxPerformerForMaskedLM(FlaxBertPreTrainedModel):
|
||||
def __init__(
|
||||
self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
|
||||
):
|
||||
module = FlaxPerformerForMaskedLMModule(
|
||||
vocab_size=config.vocab_size,
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
head_size=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_encoder_layers=config.num_hidden_layers,
|
||||
max_length=config.max_position_embeddings,
|
||||
hidden_act=config.hidden_act,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
params: dict = None,
|
||||
train: bool = False,
|
||||
dropout_rng: PRNGKey = None,
|
||||
):
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
if dropout_rng is not None:
|
||||
rngs["dropout"] = dropout_rng
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(input_ids, dtype="i4"),
|
||||
jnp.array(attention_mask, dtype="i4"),
|
||||
jnp.array(token_type_ids, dtype="i4"),
|
||||
jnp.array(position_ids, dtype="i4"),
|
||||
not train,
|
||||
rngs=rngs,
|
||||
)
|
||||
|
||||
|
||||
class FlaxPerformerForMaskedLMModule(nn.Module):
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
head_size: int
|
||||
num_heads: int
|
||||
num_encoder_layers: int
|
||||
type_vocab_size: int
|
||||
max_length: int
|
||||
hidden_act: str
|
||||
dropout_rate: float = 0.0
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
@nn.compact
|
||||
def __call__(
|
||||
self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
|
||||
):
|
||||
# Model
|
||||
encoder = FlaxPerformerModule(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
type_vocab_size=self.type_vocab_size,
|
||||
max_length=self.max_length,
|
||||
num_encoder_layers=self.num_encoder_layers,
|
||||
num_heads=self.num_heads,
|
||||
head_size=self.hidden_size,
|
||||
intermediate_size=self.intermediate_size,
|
||||
hidden_act=self.hidden_act,
|
||||
add_pooling_layer=False,
|
||||
name="bert",
|
||||
)(input_ids, attention_mask, token_type_ids, position_ids)
|
||||
|
||||
# Compute the prediction scores
|
||||
encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
|
||||
logits = FlaxBertOnlyMLMHead(
|
||||
vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
|
||||
)(encoder)
|
||||
|
||||
return (logits,)
|
||||
@@ -0,0 +1,660 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
IMPORTANT:
|
||||
|
||||
This code was copied from
|
||||
https://github.com/google-research/google-research/blob/master/performer/fast_self_attention/fast_self_attention.py on
|
||||
6/11/2020. This is very new code, so it might be prone to change soon -> make sure to check the original code and
|
||||
update accordingly
|
||||
|
||||
Core Fast Attention Module for Flax. Implementation of the approximate fast softmax and generalized attention mechanism
|
||||
leveraging structured random feature maps [RFM] techniques and low rank decomposition of the attention matrix.
|
||||
"""
|
||||
# pylint: disable=invalid-name, missing-function-docstring, line-too-long
|
||||
|
||||
import abc
|
||||
import functools
|
||||
from collections.abc import Iterable # pylint: disable=g-importing-member
|
||||
|
||||
import numpy as onp
|
||||
from absl import logging
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax, random
|
||||
|
||||
|
||||
def nonnegative_softmax_kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True, eps=0.0001
|
||||
):
|
||||
"""
|
||||
Constructs nonnegative kernel features for fast softmax attention
|
||||
|
||||
Args:
|
||||
data: input for which features are computes
|
||||
projection_matrix: random matrix used to compute features
|
||||
attention_dims_t: tuple of attention dimensions
|
||||
batch_dims_t: tuple of batch dimensions
|
||||
precision: precision parameter
|
||||
is_query: predicate indicating whether input data corresponds to queries or
|
||||
keys
|
||||
normalize_data: predicate indicating whether data should be normalized,
|
||||
eps: numerical stabilizer
|
||||
|
||||
Returns:
|
||||
Random features for fast softmax attention.
|
||||
"""
|
||||
del attention_dims_t
|
||||
if normalize_data:
|
||||
# We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
|
||||
# w_norm = w * data_normalizer for w in {q,k}.
|
||||
data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
|
||||
else:
|
||||
data_normalizer = 1.0
|
||||
ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
|
||||
data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
|
||||
data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
|
||||
|
||||
data_dash = lax.dot_general(
|
||||
data_normalizer * data,
|
||||
data_thick_random_matrix,
|
||||
(((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
|
||||
precision=precision,
|
||||
)
|
||||
|
||||
diag_data = jnp.square(data)
|
||||
diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
|
||||
diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
|
||||
diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)
|
||||
|
||||
if is_query:
|
||||
last_dims_t = (len(data_dash.shape) - 1,)
|
||||
data_dash = ratio * (
|
||||
jnp.exp(data_dash - diag_data - jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps
|
||||
)
|
||||
else:
|
||||
data_dash = ratio * (jnp.exp(data_dash - diag_data - jnp.max(data_dash)) + eps)
|
||||
|
||||
return data_dash
|
||||
|
||||
|
||||
def sincos_softmax_kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data=True
|
||||
):
|
||||
"""
|
||||
Constructs kernel sin-cos features for fast softmax attention
|
||||
|
||||
Args:
|
||||
data: input for which features are computes
|
||||
projection_matrix: random matrix used to compute features
|
||||
attention_dims_t: tuple of attention dimensions
|
||||
batch_dims_t: tuple of batch dimensions
|
||||
precision: precision parameter
|
||||
normalize_data: predicate indicating whether data should be normalized
|
||||
|
||||
Returns:
|
||||
Random features for fast softmax attention.
|
||||
"""
|
||||
if normalize_data:
|
||||
# We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
|
||||
# exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
|
||||
data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
|
||||
else:
|
||||
data_normalizer = 1.0
|
||||
ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
|
||||
data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
|
||||
data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
|
||||
|
||||
data_dash = lax.dot_general(
|
||||
data_normalizer * data,
|
||||
data_thick_random_matrix,
|
||||
(((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
|
||||
precision=precision,
|
||||
)
|
||||
data_dash_cos = ratio * jnp.cos(data_dash)
|
||||
data_dash_sin = ratio * jnp.sin(data_dash)
|
||||
data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1)
|
||||
|
||||
# Constructing D_data and data^{'}
|
||||
diag_data = jnp.square(data)
|
||||
diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
|
||||
diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
|
||||
diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)
|
||||
# Additional renormalization for numerical stability
|
||||
data_renormalizer = jnp.max(diag_data, attention_dims_t, keepdims=True)
|
||||
diag_data -= data_renormalizer
|
||||
diag_data = jnp.exp(diag_data)
|
||||
data_prime = data_dash * diag_data
|
||||
return data_prime
|
||||
|
||||
|
||||
def generalized_kernel_feature_creator(
|
||||
data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
|
||||
):
|
||||
"""
|
||||
Constructs kernel features for fast generalized attention
|
||||
|
||||
Args:
|
||||
data: input for which features are computes
|
||||
projection_matrix: matrix used to compute features
|
||||
batch_dims_t: tuple of batch dimensions
|
||||
precision: precision parameter
|
||||
kernel_fn: kernel function used
|
||||
kernel_epsilon: additive positive term added to every feature for numerical
|
||||
stability
|
||||
normalize_data: predicate indicating whether data should be normalized
|
||||
|
||||
Returns:
|
||||
Random features for fast generalized attention.
|
||||
"""
|
||||
if normalize_data:
|
||||
data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
|
||||
else:
|
||||
data_normalizer = 1.0
|
||||
if projection_matrix is None:
|
||||
return kernel_fn(data_normalizer * data) + kernel_epsilon
|
||||
else:
|
||||
data_mod_shape = data.shape[0 : len(batch_dims_t)] + projection_matrix.shape
|
||||
data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
|
||||
data_dash = lax.dot_general(
|
||||
data_normalizer * data,
|
||||
data_thick_random_matrix,
|
||||
(((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), (batch_dims_t, batch_dims_t)),
|
||||
precision=precision,
|
||||
)
|
||||
data_prime = kernel_fn(data_dash) + kernel_epsilon
|
||||
return data_prime
|
||||
|
||||
|
||||
def make_fast_softmax_attention(
|
||||
qkv_dim,
|
||||
renormalize_attention=True,
|
||||
numerical_stabilizer=0.000001,
|
||||
nb_features=256,
|
||||
ortho_features=True,
|
||||
ortho_scaling=0.0,
|
||||
redraw_features=True,
|
||||
unidirectional=False,
|
||||
nonnegative_features=True,
|
||||
lax_scan_unroll=1,
|
||||
):
|
||||
"""Construct a fast softmax attention method."""
|
||||
logging.info(
|
||||
"Fast softmax attention: %s features and orthogonal=%s, renormalize=%s",
|
||||
nb_features,
|
||||
ortho_features,
|
||||
renormalize_attention,
|
||||
)
|
||||
if ortho_features:
|
||||
matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=ortho_scaling)
|
||||
else:
|
||||
matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
|
||||
if nonnegative_features:
|
||||
|
||||
def kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
|
||||
):
|
||||
return nonnegative_softmax_kernel_feature_creator(
|
||||
data,
|
||||
projection_matrix,
|
||||
attention_dims_t,
|
||||
batch_dims_t,
|
||||
precision,
|
||||
is_query,
|
||||
normalize_data,
|
||||
numerical_stabilizer,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
def kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=True
|
||||
):
|
||||
del is_query
|
||||
return sincos_softmax_kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, normalize_data
|
||||
)
|
||||
|
||||
attention_fn = FastAttentionviaLowRankDecomposition(
|
||||
matrix_creator,
|
||||
kernel_feature_creator,
|
||||
renormalize_attention=renormalize_attention,
|
||||
numerical_stabilizer=numerical_stabilizer,
|
||||
redraw_features=redraw_features,
|
||||
unidirectional=unidirectional,
|
||||
lax_scan_unroll=lax_scan_unroll,
|
||||
).dot_product_attention
|
||||
return attention_fn
|
||||
|
||||
|
||||
def make_fast_generalized_attention(
|
||||
qkv_dim,
|
||||
renormalize_attention=True,
|
||||
numerical_stabilizer=0.0,
|
||||
nb_features=256,
|
||||
features_type="deterministic",
|
||||
kernel_fn=jax.nn.relu,
|
||||
kernel_epsilon=0.001,
|
||||
redraw_features=False,
|
||||
unidirectional=False,
|
||||
lax_scan_unroll=1,
|
||||
):
|
||||
"""Construct a fast generalized attention menthod."""
|
||||
logging.info("Fast generalized attention.: %s features and renormalize=%s", nb_features, renormalize_attention)
|
||||
if features_type == "ortho":
|
||||
matrix_creator = functools.partial(GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False)
|
||||
elif features_type == "iid":
|
||||
matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, nb_features, qkv_dim)
|
||||
elif features_type == "deterministic":
|
||||
matrix_creator = None
|
||||
else:
|
||||
raise ValueError("Unknown feature value type")
|
||||
|
||||
def kernel_feature_creator(
|
||||
data, projection_matrix, attention_dims_t, batch_dims_t, precision, is_query, normalize_data=False
|
||||
):
|
||||
del attention_dims_t
|
||||
del is_query
|
||||
return generalized_kernel_feature_creator(
|
||||
data, projection_matrix, batch_dims_t, precision, kernel_fn, kernel_epsilon, normalize_data
|
||||
)
|
||||
|
||||
attention_fn = FastAttentionviaLowRankDecomposition(
|
||||
matrix_creator,
|
||||
kernel_feature_creator,
|
||||
renormalize_attention=renormalize_attention,
|
||||
numerical_stabilizer=numerical_stabilizer,
|
||||
redraw_features=redraw_features,
|
||||
unidirectional=unidirectional,
|
||||
lax_scan_unroll=lax_scan_unroll,
|
||||
).dot_product_attention
|
||||
return attention_fn
|
||||
|
||||
|
||||
class RandomMatrix(object):
|
||||
r"""
|
||||
Abstract class providing a method for constructing 2D random arrays. Class is responsible for constructing 2D
|
||||
random arrays.
|
||||
"""
|
||||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_2d_array(self):
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
|
||||
class GaussianUnstructuredRandomMatrix(RandomMatrix):
|
||||
def __init__(self, nb_rows, nb_columns, key):
|
||||
self.nb_rows = nb_rows
|
||||
self.nb_columns = nb_columns
|
||||
self.key = key
|
||||
|
||||
def get_2d_array(self):
|
||||
return random.normal(self.key, (self.nb_rows, self.nb_columns))
|
||||
|
||||
|
||||
class GaussianOrthogonalRandomMatrix(RandomMatrix):
|
||||
r"""
|
||||
Class providing a method to create Gaussian orthogonal matrix. Class is responsible for constructing 2D Gaussian
|
||||
orthogonal arrays.
|
||||
"""
|
||||
|
||||
def __init__(self, nb_rows, nb_columns, key, scaling=0):
|
||||
self.nb_rows = nb_rows
|
||||
self.nb_columns = nb_columns
|
||||
self.key = key
|
||||
self.scaling = scaling
|
||||
|
||||
def get_2d_array(self):
|
||||
nb_full_blocks = int(self.nb_rows / self.nb_columns)
|
||||
block_list = []
|
||||
rng = self.key
|
||||
for _ in range(nb_full_blocks):
|
||||
rng, rng_input = jax.random.split(rng)
|
||||
unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns))
|
||||
q, _ = jnp.linalg.qr(unstructured_block)
|
||||
q = jnp.transpose(q)
|
||||
block_list.append(q)
|
||||
remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns
|
||||
if remaining_rows > 0:
|
||||
rng, rng_input = jax.random.split(rng)
|
||||
unstructured_block = random.normal(rng_input, (self.nb_columns, self.nb_columns))
|
||||
q, _ = jnp.linalg.qr(unstructured_block)
|
||||
q = jnp.transpose(q)
|
||||
block_list.append(q[0:remaining_rows])
|
||||
final_matrix = jnp.vstack(block_list)
|
||||
|
||||
if self.scaling == 0:
|
||||
multiplier = jnp.linalg.norm(random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1)
|
||||
elif self.scaling == 1:
|
||||
multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows))
|
||||
else:
|
||||
raise ValueError("Scaling must be one of {0, 1}. Was %s" % self._scaling)
|
||||
|
||||
return jnp.matmul(jnp.diag(multiplier), final_matrix)
|
||||
|
||||
|
||||
class FastAttention(object):
|
||||
r"""
|
||||
Abstract class providing a method for fast attention. Class is responsible for providing a method
|
||||
<dot_product_attention> for fast approximate attention.
|
||||
"""
|
||||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractmethod
|
||||
def dot_product_attention(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dtype=jnp.float32,
|
||||
bias=None,
|
||||
axis=None,
|
||||
broadcast_dropout=True,
|
||||
dropout_rng=None,
|
||||
dropout_rate=0.0,
|
||||
deterministic=False,
|
||||
precision=None,
|
||||
):
|
||||
"""
|
||||
Computes dot-product attention given query, key, and value. This is the core function for applying fast
|
||||
approximate dot-product attention. It calculates the attention weights given query and key and combines the
|
||||
values using the attention weights. This function supports multi-dimensional inputs
|
||||
|
||||
Args:
|
||||
query: queries for calculating attention with shape of [batch_size, dim1,
|
||||
dim2, ..., dimN, num_heads, mem_channels].
|
||||
key: keys for calculating attention with shape of [batch_size, dim1, dim2,
|
||||
..., dimN, num_heads, mem_channels].
|
||||
value: values to be used in attention with shape of [batch_size, dim1,
|
||||
dim2,..., dimN, num_heads, value_channels].
|
||||
dtype: the dtype of the computation (default: float32)
|
||||
bias: bias for the attention weights. This can be used for incorporating
|
||||
autoregressive mask, padding mask, proximity bias.
|
||||
axis: axises over which the attention is applied.
|
||||
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
|
||||
dropout_rng: JAX PRNGKey: to be used for dropout.
|
||||
dropout_rate: dropout rate.
|
||||
deterministic: bool, deterministic or not (to apply dropout).
|
||||
precision: numerical precision of the computation see `jax.lax.Precision`
|
||||
for details
|
||||
|
||||
Returns:
|
||||
Output of shape [bs, dim1, dim2, ..., dimN,, num_heads, value_channels].
|
||||
"""
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
|
||||
def _numerator(z_slice_shape, precision, unroll=1):
|
||||
def fwd(qs, ks, vs):
|
||||
def body(p, qkv):
|
||||
(q, k, v) = qkv
|
||||
p += jnp.einsum("...m,...d->...md", k, v, precision=precision)
|
||||
X_slice = jnp.einsum("...m,...md->...d", q, p, precision=precision)
|
||||
return p, X_slice
|
||||
|
||||
init_value = jnp.zeros(z_slice_shape)
|
||||
p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll)
|
||||
return W, (p, qs, ks, vs)
|
||||
|
||||
def bwd(pqkv, W_ct):
|
||||
def body(carry, qkv_xct):
|
||||
p, p_ct = carry
|
||||
q, k, v, x_ct = qkv_xct
|
||||
q_ct = jnp.einsum("...d,...md->...m", x_ct, p, precision=precision)
|
||||
p_ct += jnp.einsum("...d,...m->...md", x_ct, q, precision=precision)
|
||||
k_ct = jnp.einsum("...md,...d->...m", p_ct, v, precision=precision)
|
||||
v_ct = jnp.einsum("...md,...m->...d", p_ct, k, precision=precision)
|
||||
p -= jnp.einsum("...m,...d->...md", k, v, precision=precision)
|
||||
return (p, p_ct), (q_ct, k_ct, v_ct)
|
||||
|
||||
p, qs, ks, vs = pqkv
|
||||
_, (qs_ct, ks_ct, vs_ct) = lax.scan(
|
||||
body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct), reverse=True, unroll=unroll
|
||||
)
|
||||
return qs_ct, ks_ct, vs_ct
|
||||
|
||||
@jax.custom_vjp
|
||||
def _numerator_impl(qs, ks, vs):
|
||||
W, _ = fwd(qs, ks, vs)
|
||||
return W
|
||||
|
||||
_numerator_impl.defvjp(fwd, bwd)
|
||||
|
||||
return _numerator_impl
|
||||
|
||||
|
||||
def _denominator(t_slice_shape, precision, unroll=1):
|
||||
def fwd(qs, ks):
|
||||
def body(p, qk):
|
||||
q, k = qk
|
||||
p += k
|
||||
x = jnp.einsum("...m,...m->...", q, p, precision=precision)
|
||||
return p, x
|
||||
|
||||
p = jnp.zeros(t_slice_shape)
|
||||
p, R = lax.scan(body, p, (qs, ks), unroll=unroll)
|
||||
return R, (qs, ks, p)
|
||||
|
||||
def bwd(qkp, R_ct):
|
||||
def body(carry, qkx):
|
||||
p, p_ct = carry
|
||||
q, k, x_ct = qkx
|
||||
q_ct = jnp.einsum("...,...m->...m", x_ct, p, precision=precision)
|
||||
p_ct += jnp.einsum("...,...m->...m", x_ct, q, precision=precision)
|
||||
k_ct = p_ct
|
||||
p -= k
|
||||
return (p, p_ct), (q_ct, k_ct)
|
||||
|
||||
qs, ks, p = qkp
|
||||
_, (qs_ct, ks_ct) = lax.scan(body, (p, jnp.zeros_like(p)), (qs, ks, R_ct), reverse=True, unroll=unroll)
|
||||
return (qs_ct, ks_ct)
|
||||
|
||||
@jax.custom_vjp
|
||||
def _denominator_impl(qs, ks):
|
||||
R, _ = fwd(qs, ks)
|
||||
return R
|
||||
|
||||
_denominator_impl.defvjp(fwd, bwd)
|
||||
|
||||
return _denominator_impl
|
||||
|
||||
|
||||
class FastAttentionviaLowRankDecomposition(FastAttention):
|
||||
r"""
|
||||
Class providing a method for fast attention via low rank decomposition. Class is responsible for providing a method
|
||||
<dot_product_attention> for fast dot-product attention with the use of low rank decomposition (e.g. with random
|
||||
feature maps).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
matrix_creator,
|
||||
kernel_feature_creator,
|
||||
renormalize_attention,
|
||||
numerical_stabilizer,
|
||||
redraw_features,
|
||||
unidirectional,
|
||||
lax_scan_unroll=1,
|
||||
): # For optimal GPU performance, set to 16.
|
||||
rng = random.PRNGKey(0)
|
||||
self.matrix_creator = matrix_creator
|
||||
self.projection_matrix = self.draw_weights(rng)
|
||||
self.kernel_feature_creator = kernel_feature_creator
|
||||
self.renormalize_attention = renormalize_attention
|
||||
self.numerical_stabilizer = numerical_stabilizer
|
||||
self.redraw_features = redraw_features
|
||||
self.unidirectional = unidirectional
|
||||
self.lax_scan_unroll = lax_scan_unroll
|
||||
|
||||
def draw_weights(self, key):
|
||||
if self.matrix_creator is None:
|
||||
return None
|
||||
matrixrng, _ = random.split(key)
|
||||
projection_matrix = self.matrix_creator(key=matrixrng).get_2d_array()
|
||||
return projection_matrix
|
||||
|
||||
def dot_product_attention(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dtype=jnp.float32,
|
||||
bias=None,
|
||||
axis=None,
|
||||
broadcast_dropout=True,
|
||||
dropout_rng=None,
|
||||
dropout_rate=0.0,
|
||||
deterministic=False,
|
||||
precision=None,
|
||||
):
|
||||
|
||||
assert key.shape[:-1] == value.shape[:-1]
|
||||
assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1]
|
||||
if axis is None:
|
||||
axis = tuple(range(1, key.ndim - 2))
|
||||
if not isinstance(axis, Iterable):
|
||||
axis = (axis,)
|
||||
assert key.ndim == query.ndim
|
||||
assert key.ndim == value.ndim
|
||||
for ax in axis:
|
||||
if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
|
||||
raise ValueError("Attention axis must be between the batch " "axis and the last-two axes.")
|
||||
n = key.ndim
|
||||
|
||||
# Constructing projection tensor.
|
||||
if self.redraw_features:
|
||||
# TODO(kchoro): Get rid of the constant below.
|
||||
query_seed = lax.convert_element_type(jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
|
||||
rng = random.PRNGKey(query_seed)
|
||||
self.projection_matrix = self.draw_weights(rng)
|
||||
|
||||
# batch_dims is <bs, <non-attention dims>, num_heads>
|
||||
batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
|
||||
# q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
|
||||
qk_perm = batch_dims + axis + (n - 1,)
|
||||
k_extra_perm = axis + batch_dims + (n - 1,)
|
||||
key_extra = key.transpose(k_extra_perm)
|
||||
key = key.transpose(qk_perm)
|
||||
query = query.transpose(qk_perm)
|
||||
# v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
|
||||
v_perm = batch_dims + axis + (n - 1,)
|
||||
value = value.transpose(v_perm)
|
||||
batch_dims_t = tuple(range(len(batch_dims)))
|
||||
attention_dims_t = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))
|
||||
|
||||
# Constructing tensors Q^{'} and K^{'}.
|
||||
query_prime = self.kernel_feature_creator(
|
||||
query, self.projection_matrix, attention_dims_t, batch_dims_t, precision, True
|
||||
)
|
||||
key_prime = self.kernel_feature_creator(
|
||||
key, self.projection_matrix, attention_dims_t, batch_dims_t, precision, False
|
||||
)
|
||||
|
||||
if self.unidirectional:
|
||||
index = attention_dims_t[0]
|
||||
z_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],) + (value.shape[-1],)
|
||||
|
||||
numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll)
|
||||
W = numerator_fn(
|
||||
jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0)
|
||||
)
|
||||
|
||||
# Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
|
||||
W = jnp.moveaxis(W, 0, index)
|
||||
|
||||
if not self.renormalize_attention:
|
||||
# Unidirectional, not-normalized attention.
|
||||
perm_inv = _invert_perm(qk_perm)
|
||||
result = W.transpose(perm_inv)
|
||||
return result
|
||||
else:
|
||||
# Unidirectional, normalized attention.
|
||||
thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)])
|
||||
|
||||
index = attention_dims_t[0]
|
||||
t_slice_shape = key_prime.shape[0 : len(batch_dims_t)] + (key_prime.shape[-1],)
|
||||
denominator_fn = _denominator(t_slice_shape, precision, self.lax_scan_unroll)
|
||||
R = denominator_fn(jnp.moveaxis(query_prime, index, 0), jnp.moveaxis(key_prime, index, 0))
|
||||
|
||||
R = jnp.moveaxis(R, 0, index)
|
||||
else:
|
||||
contract_query = tuple(range(len(batch_dims) + len(axis), len(batch_dims) + len(axis) + 1))
|
||||
contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
|
||||
# Constructing Z = (K^{'})^{T}V
|
||||
# Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
|
||||
Z = lax.dot_general(
|
||||
key_prime,
|
||||
value,
|
||||
((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
|
||||
precision=precision,
|
||||
)
|
||||
# Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
|
||||
# q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
|
||||
# Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
|
||||
# W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
|
||||
W = lax.dot_general(
|
||||
query_prime, Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)), precision=precision
|
||||
)
|
||||
if not self.renormalize_attention:
|
||||
# Bidirectional, not-normalized attention.
|
||||
perm_inv = _invert_perm(qk_perm)
|
||||
result = W.transpose(perm_inv)
|
||||
return result
|
||||
else:
|
||||
# Bidirectional, normalized attention.
|
||||
thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(key_extra.shape[0 : len(axis)])
|
||||
contract_key = tuple(range(len(batch_dims), len(batch_dims) + len(axis)))
|
||||
contract_thick_all_ones = tuple(range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim))
|
||||
# Construct T = (K^{'})^{T} 1_L
|
||||
# k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
|
||||
T = lax.dot_general(
|
||||
key_prime,
|
||||
thick_all_ones,
|
||||
((contract_key, contract_thick_all_ones), (batch_dims_t, batch_dims_t)),
|
||||
precision=precision,
|
||||
)
|
||||
|
||||
# Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
|
||||
# q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
|
||||
# T (bs, <non-attention dims>, num_heads, channels_m)
|
||||
R = lax.dot_general(
|
||||
query_prime,
|
||||
T,
|
||||
(((query_prime.ndim - 1,), (T.ndim - 1,)), (batch_dims_t, range(0, len(T.shape) - 1))),
|
||||
precision=precision,
|
||||
)
|
||||
|
||||
R = R + 2 * self.numerical_stabilizer * (jnp.abs(R) <= self.numerical_stabilizer)
|
||||
R = jnp.reciprocal(R)
|
||||
R = jnp.expand_dims(R, len(R.shape))
|
||||
# W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
|
||||
# R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
|
||||
result = W * R
|
||||
# back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
|
||||
perm_inv = _invert_perm(qk_perm)
|
||||
result = result.transpose(perm_inv)
|
||||
return result
|
||||
|
||||
|
||||
def _invert_perm(perm):
|
||||
perm_inv = [0] * len(perm)
|
||||
for i, j in enumerate(perm):
|
||||
perm_inv[j] = i
|
||||
return tuple(perm_inv)
|
||||
685
examples/research_projects/performer/run_mlm_performer.py
Normal file
685
examples/research_projects/performer/run_mlm_performer.py
Normal file
@@ -0,0 +1,685 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Team All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
|
||||
text file or a dataset.
|
||||
|
||||
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
||||
https://huggingface.co/models?filter=masked-lm
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax import jax_utils
|
||||
from flax.optim import Adam
|
||||
from flax.training import common_utils
|
||||
from flax.training.common_utils import get_metrics
|
||||
from jax.nn import log_softmax
|
||||
from modeling_flax_performer import FlaxPerformerForMaskedLM
|
||||
from transformers import (
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
AutoTokenizer,
|
||||
BertConfig,
|
||||
FlaxBertForMaskedLM,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
TensorType,
|
||||
TrainingArguments,
|
||||
is_tensorboard_available,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
|
||||
# Cache the result
|
||||
has_tensorboard = is_tensorboard_available()
|
||||
if has_tensorboard:
|
||||
try:
|
||||
from flax.metrics.tensorboard import SummaryWriter
|
||||
except ImportError as ie:
|
||||
has_tensorboard = False
|
||||
print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
|
||||
|
||||
else:
|
||||
print(
|
||||
"Unable to display metrics through TensorBoard because the package is not installed: "
|
||||
"Please run pip install tensorboard to enable."
|
||||
)
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WandbArguments:
|
||||
"""
|
||||
Arguments for logging
|
||||
"""
|
||||
|
||||
wandb_user_name: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The WandB user name for potential logging. If left None, no logging"},
|
||||
)
|
||||
wandb_project_name: Optional[str] = field(
|
||||
default="performer-experiments",
|
||||
metadata={"help": "The WandB project name for potential logging"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
||||
"""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The model checkpoint for weights initialization."
|
||||
"Don't set if you want to train a model from scratch."
|
||||
},
|
||||
)
|
||||
performer: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use FAVOR+ attention"},
|
||||
)
|
||||
reinitialize: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use a blank model without pretraining"},
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||
)
|
||||
use_fast_tokenizer: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataTrainingArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and eval.
|
||||
"""
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
dataset_config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
||||
)
|
||||
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||
)
|
||||
train_ref_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
|
||||
)
|
||||
validation_ref_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
)
|
||||
validation_split_percentage: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={
|
||||
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
||||
},
|
||||
)
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated. Default to the max input length of the model."
|
||||
},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
)
|
||||
mlm_probability: float = field(
|
||||
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||
raise ValueError("Need either a dataset name or a training/validation file.")
|
||||
else:
|
||||
if self.train_file is not None:
|
||||
extension = self.train_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
||||
if self.validation_file is not None:
|
||||
extension = self.validation_file.split(".")[-1]
|
||||
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
||||
|
||||
|
||||
# Adapted from transformers/data/data_collator.py
|
||||
# Letting here for now, let's discuss where it should live
|
||||
@dataclass
|
||||
class FlaxDataCollatorForLanguageModeling:
|
||||
"""
|
||||
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
||||
are not all of the same length.
|
||||
|
||||
Args:
|
||||
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
||||
The tokenizer used for encoding the data.
|
||||
mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
|
||||
inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
|
||||
non-masked tokens and the value to predict for the masked token.
|
||||
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
||||
The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
|
||||
|
||||
.. note::
|
||||
|
||||
For best performance, this data collator should be used with a dataset having items that are dictionaries or
|
||||
BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
|
||||
:class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
|
||||
argument :obj:`return_special_tokens_mask=True`.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
mlm: bool = True
|
||||
mlm_probability: float = 0.15
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mlm and self.tokenizer.mask_token is None:
|
||||
raise ValueError(
|
||||
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
||||
"You should pass `mlm=False` to train on causal language modeling instead."
|
||||
)
|
||||
|
||||
def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
|
||||
|
||||
# If special token mask has been preprocessed, pop it from the dict.
|
||||
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
||||
if self.mlm:
|
||||
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
||||
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
||||
)
|
||||
else:
|
||||
labels = batch["input_ids"].copy()
|
||||
if self.tokenizer.pad_token_id is not None:
|
||||
labels[labels == self.tokenizer.pad_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
return batch
|
||||
|
||||
def mask_tokens(
|
||||
self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
||||
"""
|
||||
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
||||
"""
|
||||
labels = inputs.copy()
|
||||
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
||||
probability_matrix = np.full(labels.shape, self.mlm_probability)
|
||||
special_tokens_mask = special_tokens_mask.astype("bool")
|
||||
|
||||
probability_matrix[special_tokens_mask] = 0.0
|
||||
masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
|
||||
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||
|
||||
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||
indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
|
||||
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
||||
|
||||
# 10% of the time, we replace masked input tokens with random word
|
||||
indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
|
||||
indices_random &= masked_indices & ~indices_replaced
|
||||
|
||||
random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
|
||||
inputs[indices_random] = random_words[indices_random]
|
||||
|
||||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
||||
return inputs, labels
|
||||
|
||||
|
||||
def create_learning_rate_scheduler(
|
||||
factors="constant * linear_warmup * rsqrt_decay",
|
||||
base_learning_rate=0.5,
|
||||
warmup_steps=1000,
|
||||
decay_factor=0.5,
|
||||
steps_per_decay=20000,
|
||||
steps_per_cycle=100000,
|
||||
):
|
||||
"""Creates learning rate schedule.
|
||||
Interprets factors in the factors string which can consist of:
|
||||
* constant: interpreted as the constant value,
|
||||
* linear_warmup: interpreted as linear warmup until warmup_steps,
|
||||
* rsqrt_decay: divide by square root of max(step, warmup_steps)
|
||||
* rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1)
|
||||
* decay_every: Every k steps decay the learning rate by decay_factor.
|
||||
* cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter.
|
||||
Args:
|
||||
factors: string, factors separated by "*" that defines the schedule.
|
||||
base_learning_rate: float, the starting constant for the lr schedule.
|
||||
warmup_steps: int, how many steps to warm up for in the warmup schedule.
|
||||
decay_factor: float, the amount to decay the learning rate by.
|
||||
steps_per_decay: int, how often to decay the learning rate.
|
||||
steps_per_cycle: int, steps per cycle when using cosine decay.
|
||||
Returns:
|
||||
a function learning_rate(step): float -> {"learning_rate": float}, the
|
||||
step-dependent lr.
|
||||
"""
|
||||
factors = [n.strip() for n in factors.split("*")]
|
||||
|
||||
def step_fn(step):
|
||||
"""Step to learning rate function."""
|
||||
ret = 1.0
|
||||
for name in factors:
|
||||
if name == "constant":
|
||||
ret *= base_learning_rate
|
||||
elif name == "linear_warmup":
|
||||
ret *= jnp.minimum(1.0, step / warmup_steps)
|
||||
elif name == "rsqrt_decay":
|
||||
ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
|
||||
elif name == "rsqrt_normalized_decay":
|
||||
ret *= jnp.sqrt(warmup_steps)
|
||||
ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
|
||||
elif name == "decay_every":
|
||||
ret *= decay_factor ** (step // steps_per_decay)
|
||||
elif name == "cosine_decay":
|
||||
progress = jnp.maximum(0.0, (step - warmup_steps) / float(steps_per_cycle))
|
||||
ret *= jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
|
||||
else:
|
||||
raise ValueError("Unknown factor %s." % name)
|
||||
return jnp.asarray(ret, dtype=jnp.float32)
|
||||
|
||||
return step_fn
|
||||
|
||||
|
||||
def compute_metrics(logits, labels, weights, label_smoothing=0.0):
|
||||
"""Compute summary metrics."""
|
||||
loss, normalizer = cross_entropy(logits, labels, weights, label_smoothing)
|
||||
acc, _ = accuracy(logits, labels, weights)
|
||||
metrics = {"loss": loss, "accuracy": acc, "normalizer": normalizer}
|
||||
metrics = jax.lax.psum(metrics, axis_name="batch")
|
||||
return metrics
|
||||
|
||||
|
||||
def accuracy(logits, targets, weights=None):
|
||||
"""Compute weighted accuracy for log probs and targets.
|
||||
Args:
|
||||
logits: [batch, length, num_classes] float array.
|
||||
targets: categorical targets [batch, length] int array.
|
||||
weights: None or array of shape [batch, length]
|
||||
Returns:
|
||||
Tuple of scalar loss and batch normalizing factor.
|
||||
"""
|
||||
if logits.ndim != targets.ndim + 1:
|
||||
raise ValueError(
|
||||
"Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape))
|
||||
)
|
||||
|
||||
loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)
|
||||
loss *= weights
|
||||
|
||||
return loss.sum(), weights.sum()
|
||||
|
||||
|
||||
def cross_entropy(logits, targets, weights=None, label_smoothing=0.0):
|
||||
"""Compute cross entropy and entropy for log probs and targets.
|
||||
Args:
|
||||
logits: [batch, length, num_classes] float array.
|
||||
targets: categorical targets [batch, length] int array.
|
||||
weights: None or array of shape [batch, length]
|
||||
label_smoothing: label smoothing constant, used to determine the on and off values.
|
||||
Returns:
|
||||
Tuple of scalar loss and batch normalizing factor.
|
||||
"""
|
||||
if logits.ndim != targets.ndim + 1:
|
||||
raise ValueError(
|
||||
"Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape))
|
||||
)
|
||||
|
||||
vocab_size = logits.shape[-1]
|
||||
confidence = 1.0 - label_smoothing
|
||||
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
||||
normalizing_constant = -(
|
||||
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
|
||||
)
|
||||
soft_targets = common_utils.onehot(targets, vocab_size, on_value=confidence, off_value=low_confidence)
|
||||
|
||||
loss = -jnp.sum(soft_targets * log_softmax(logits), axis=-1)
|
||||
loss = loss - normalizing_constant
|
||||
|
||||
if weights is not None:
|
||||
loss = loss * weights
|
||||
normalizing_factor = weights.sum()
|
||||
else:
|
||||
normalizing_factor = np.prod(targets.shape)
|
||||
|
||||
return loss.sum(), normalizing_factor
|
||||
|
||||
|
||||
def training_step(optimizer, batch, dropout_rng):
|
||||
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
||||
|
||||
def loss_fn(params):
|
||||
targets = batch.pop("labels")
|
||||
|
||||
# Hide away tokens which doesn't participate in the optimization
|
||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
||||
|
||||
logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||
loss, weight_sum = cross_entropy(logits, targets, token_mask)
|
||||
return loss / weight_sum
|
||||
|
||||
step = optimizer.state.step
|
||||
lr = lr_scheduler_fn(step)
|
||||
grad_fn = jax.value_and_grad(loss_fn)
|
||||
loss, grad = grad_fn(optimizer.target)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
|
||||
|
||||
return loss, optimizer, new_dropout_rng
|
||||
|
||||
|
||||
def eval_step(params, batch):
|
||||
"""
|
||||
Calculate evaluation metrics on a batch.
|
||||
"""
|
||||
targets = batch.pop("labels")
|
||||
|
||||
# Hide away tokens which doesn't participate in the optimization
|
||||
token_mask = jnp.where(targets > 0, 1.0, 0.0)
|
||||
logits = model(**batch, params=params, train=False)[0]
|
||||
|
||||
return compute_metrics(logits, targets, token_mask)
|
||||
|
||||
|
||||
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
||||
nb_samples = len(samples_idx)
|
||||
samples_to_remove = nb_samples % batch_size
|
||||
|
||||
if samples_to_remove != 0:
|
||||
samples_idx = samples_idx[:-samples_to_remove]
|
||||
sections_split = nb_samples // batch_size
|
||||
batch_idx = np.split(samples_idx, sections_split)
|
||||
return batch_idx
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, WandbArguments))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
# If we pass only one argument to the script and it's the path to a json file,
|
||||
# let's parse it to get our arguments.
|
||||
model_args, data_args, training_args, wandb_args = parser.parse_json_file(
|
||||
json_file=os.path.abspath(sys.argv[1])
|
||||
)
|
||||
else:
|
||||
model_args, data_args, training_args, wandb_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if (
|
||||
os.path.exists(training_args.output_dir)
|
||||
and os.listdir(training_args.output_dir)
|
||||
and training_args.do_train
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
level="NOTSET",
|
||||
datefmt="[%X]",
|
||||
)
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
logger.info("Training/evaluation parameters %s", training_args)
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args.seed)
|
||||
|
||||
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
||||
#
|
||||
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
||||
# 'text' is found. You can easily tweak this behavior (see below).
|
||||
#
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if data_args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
|
||||
if "validation" not in datasets.keys():
|
||||
datasets["validation"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"train[:{data_args.validation_split_percentage}%]",
|
||||
)
|
||||
datasets["train"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
data_args.dataset_config_name,
|
||||
split=f"train[{data_args.validation_split_percentage}%:]",
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if data_args.train_file is not None:
|
||||
data_files["train"] = data_args.train_file
|
||||
if data_args.validation_file is not None:
|
||||
data_files["validation"] = data_args.validation_file
|
||||
extension = data_args.train_file.split(".")[-1]
|
||||
if extension == "txt":
|
||||
extension = "text"
|
||||
datasets = load_dataset(extension, data_files=data_files)
|
||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
|
||||
rng = jax.random.PRNGKey(training_args.seed)
|
||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
config = BertConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
||||
lm_class = FlaxPerformerForMaskedLM if model_args.performer else FlaxBertForMaskedLM
|
||||
if model_args.reinitialize:
|
||||
model = lm_class(config=BertConfig.from_pretrained(model_args.model_name_or_path))
|
||||
else:
|
||||
model = lm_class.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
dtype=jnp.float32,
|
||||
input_shape=(training_args.train_batch_size, config.max_position_embeddings),
|
||||
seed=training_args.seed,
|
||||
dropout_rate=0.1,
|
||||
)
|
||||
|
||||
if model_args.tokenizer_name:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
elif model_args.model_name_or_path:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
||||
)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# First we tokenize all the texts.
|
||||
if training_args.do_train:
|
||||
column_names = datasets["train"].column_names
|
||||
else:
|
||||
column_names = datasets["validation"].column_names
|
||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
def tokenize_function(examples):
|
||||
# Remove empty lines
|
||||
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
||||
return tokenizer(
|
||||
examples,
|
||||
return_special_tokens_mask=True,
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
max_length=data_args.max_seq_length,
|
||||
)
|
||||
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
input_columns=[text_column_name],
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Enable tensorboard only on the master node
|
||||
if has_tensorboard and jax.host_id() == 0:
|
||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
|
||||
|
||||
# Data collator
|
||||
# This one will take care of randomly masking the tokens.
|
||||
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
||||
|
||||
# Setup optimizer
|
||||
optimizer = Adam(
|
||||
learning_rate=training_args.learning_rate,
|
||||
weight_decay=training_args.weight_decay,
|
||||
beta1=training_args.adam_beta1,
|
||||
beta2=training_args.adam_beta2,
|
||||
).create(model.params)
|
||||
|
||||
# Create learning rate scheduler
|
||||
lr_scheduler_fn = create_learning_rate_scheduler(
|
||||
base_learning_rate=training_args.learning_rate, warmup_steps=max(training_args.warmup_steps, 1)
|
||||
)
|
||||
|
||||
# Create parallel version of the training and evaluation steps
|
||||
p_training_step = jax.pmap(training_step, "batch", donate_argnums=(0,))
|
||||
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
||||
|
||||
# Replicate the optimizer on each device
|
||||
optimizer = jax_utils.replicate(optimizer)
|
||||
|
||||
# Store some constant
|
||||
nb_epochs = int(training_args.num_train_epochs)
|
||||
batch_size = int(training_args.train_batch_size)
|
||||
eval_batch_size = int(training_args.eval_batch_size)
|
||||
|
||||
if wandb_args.wandb_user_name is not None:
|
||||
import wandb
|
||||
|
||||
wandb.init(project=wandb_args.wandb_project_name, entity=wandb_args.wandb_user_name)
|
||||
|
||||
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
|
||||
# ======================== Training ================================
|
||||
# Create sampling rng
|
||||
rng, training_rng, eval_rng = jax.random.split(rng, 3)
|
||||
|
||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||
nb_training_samples = len(tokenized_datasets["train"])
|
||||
training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
|
||||
training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)
|
||||
|
||||
# Gather the indexes for creating the batch and do a training step
|
||||
for batch_idx in tqdm(training_batch_idx, desc="Training...", position=1):
|
||||
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = common_utils.shard(model_inputs.data)
|
||||
loss, optimizer, dropout_rngs = p_training_step(optimizer, model_inputs, dropout_rngs)
|
||||
|
||||
if wandb_args.wandb_user_name is not None:
|
||||
wandb.log({"Training loss": np.array(loss).mean()})
|
||||
|
||||
epochs.write(f"Loss: {loss}")
|
||||
|
||||
# ======================== Evaluating ==============================
|
||||
nb_eval_samples = len(tokenized_datasets["validation"])
|
||||
eval_samples_idx = jnp.arange(nb_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
|
||||
eval_metrics = []
|
||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = common_utils.shard(model_inputs.data)
|
||||
metrics = p_eval_step(optimizer.target, model_inputs)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
eval_metrics_np = get_metrics(eval_metrics)
|
||||
eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
|
||||
eval_normalizer = eval_metrics_np.pop("normalizer")
|
||||
eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
|
||||
|
||||
# Update progress bar
|
||||
epochs.desc = (
|
||||
f"Epoch... ({epoch + 1}/{nb_epochs} | Loss: {eval_summary['loss']}, Acc: {eval_summary['accuracy']})"
|
||||
)
|
||||
|
||||
if wandb_args.wandb_user_name is not None:
|
||||
wandb.log({"Eval loss": np.array(eval_summary["loss"]).mean()})
|
||||
|
||||
# Save metrics
|
||||
if has_tensorboard and jax.host_id() == 0:
|
||||
for name, value in eval_summary.items():
|
||||
summary_writer.scalar(name, value, epoch)
|
||||
1
examples/research_projects/performer/sanity_script.sh
Executable file
1
examples/research_projects/performer/sanity_script.sh
Executable file
@@ -0,0 +1 @@
|
||||
TOKENIZERS_PARALLELISM=true python run_mlm_performer.py --output_dir experiments --dataset_name wikipedia --dataset_config_name 20200501.simple --model_name_or_path bert-base-cased --tokenizer_name bert-base-cased --do_train --overwrite_output_dir --per_device_train_batch_size 4 --learning_rate 5e-4 --warmup_steps 100 --num_train_epochs 3 --performer
|
||||
@@ -23,10 +23,10 @@ test.source
|
||||
test.target
|
||||
```
|
||||
|
||||
A sample finetuning command (run ` ./examples/rag/finetune_rag.py --help` to list all available options):
|
||||
A sample finetuning command (run ` ./examples/research_projects/rag/finetune_rag.py --help` to list all available options):
|
||||
|
||||
```bash
|
||||
python examples/rag/finetune_rag.py \
|
||||
python examples/research_projects/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
@@ -42,7 +42,7 @@ The `base` models initialize the question encoder with [`facebook/dpr-question_e
|
||||
|
||||
If you would like to initialize finetuning with a base model using different question encoder and generator architectures, you can build it with a consolidation script, e.g.:
|
||||
```
|
||||
python examples/rag/consolidate_rag_checkpoint.py \
|
||||
python examples/research_projects/rag/consolidate_rag_checkpoint.py \
|
||||
--model_type rag_sequence \
|
||||
--generator_name_or_path facebook/bart-large-cnn \
|
||||
--question_encoder_name_or_path facebook/dpr-question_encoder-single-nq-base \
|
||||
@@ -50,6 +50,44 @@ python examples/rag/consolidate_rag_checkpoint.py \
|
||||
```
|
||||
You will then be able to pass `path/to/checkpoint` as `model_name_or_path` to the `finetune_rag.py` script.
|
||||
|
||||
## Document Retrieval
|
||||
When running distributed fine-tuning, each training worker needs to retrieve contextual documents
|
||||
for its input by querying a index loaded into memory. RAG provides two implementations for document retrieval,
|
||||
one with [`torch.distributed`](https://pytorch.org/docs/stable/distributed.html) communication package and the other
|
||||
with [`Ray`](https://docs.ray.io/en/master/).
|
||||
|
||||
This option can be configured with the `--distributed_retriever` flag which can either be set to `pytorch` or `ray`.
|
||||
By default this flag is set to `pytorch`.
|
||||
|
||||
For the Pytorch implementation, only training worker 0 loads the index into CPU memory, and a gather/scatter pattern is used
|
||||
to collect the inputs from the other training workers and send back the corresponding document embeddings.
|
||||
|
||||
For the Ray implementation, the index is loaded in *separate* process(es). The training workers randomly select which
|
||||
retriever worker to query. To use Ray for distributed retrieval, you have to set the `--distributed_retriever` arg to `ray`.
|
||||
To configure the number of retrieval workers (the number of processes that load the index), you can set the `num_retrieval_workers` flag.
|
||||
Also make sure to start the Ray cluster before running fine-tuning.
|
||||
|
||||
```bash
|
||||
# Start a single-node Ray cluster.
|
||||
ray start --head
|
||||
|
||||
python examples/research_projects/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8
|
||||
--distributed_retriever ray \
|
||||
--num_retrieval_workers 4
|
||||
|
||||
# Stop the ray cluster once fine-tuning has finished.
|
||||
ray stop
|
||||
```
|
||||
|
||||
Using Ray can lead to retrieval speedups on multi-GPU settings since multiple processes load the index rather than
|
||||
just the rank 0 training worker. Using Ray also allows you to load the index on GPU since the index is loaded on a separate
|
||||
processes than the model, while with pytorch distributed retrieval, both are loaded in the same process potentially leading to GPU OOM.
|
||||
|
||||
# Evaluation
|
||||
Our evaluation script enables two modes of evaluation (controlled by the `eval_mode` argument): `e2e` - end2end evaluation, returns EM (exact match) and F1 scores calculated for the downstream task and `retrieval` - which returns precision@k of the documents retrieved for provided inputs.
|
||||
@@ -75,14 +113,14 @@ We demonstrate how to evaluate retrieval against DPR evaluation data. You can do
|
||||
2. Parse the unziped file using the `parse_dpr_relevance_data.py`
|
||||
```bash
|
||||
mkdir output # or wherever you want to save this
|
||||
python examples/rag/parse_dpr_relevance_data.py \
|
||||
python examples/research_projects/rag/parse_dpr_relevance_data.py \
|
||||
--src_path biencoder-nq-dev.json \
|
||||
--evaluation_set output/biencoder-nq-dev.questions \
|
||||
--gold_data_path output/biencoder-nq-dev.pages
|
||||
```
|
||||
3. Run evaluation:
|
||||
```bash
|
||||
python examples/rag/eval_rag.py \
|
||||
python examples/research_projects/rag/eval_rag.py \
|
||||
--model_name_or_path facebook/rag-sequence-nq \
|
||||
--model_type rag_sequence \
|
||||
--evaluation_set output/biencoder-nq-dev.questions \
|
||||
@@ -93,7 +131,7 @@ We demonstrate how to evaluate retrieval against DPR evaluation data. You can do
|
||||
```
|
||||
```bash
|
||||
# EXPLANATION
|
||||
python examples/rag/eval_rag.py \
|
||||
python examples/research_projects/rag/eval_rag.py \
|
||||
--model_name_or_path facebook/rag-sequence-nq \ # model name or path of the model we're evaluating
|
||||
--model_type rag_sequence \ # RAG model type (rag_token or rag_sequence)
|
||||
--evaluation_set output/biencoder-nq-dev.questions \ # an input dataset for evaluation
|
||||
@@ -121,7 +159,7 @@ Add `--recalculate` parameter to force the script to perform inference from scra
|
||||
|
||||
An example e2e evaluation run could look as follows:
|
||||
```bash
|
||||
python examples/rag/eval_rag.py \
|
||||
python examples/research_projects/rag/eval_rag.py \
|
||||
--model_name_or_path facebook/rag-sequence-nq \
|
||||
--model_type rag_sequence \
|
||||
--evaluation_set path/to/test.source \
|
||||
@@ -141,14 +179,14 @@ With `use_custom_knowledge_dataset.py` you can build your own knowledge source,
|
||||
|
||||
For instance, if documents are serialized as tab-separated csv files with the columns "title" and "text", one can use `use_own_knowledge_dataset.py` as follows:
|
||||
```bash
|
||||
python examples/rag/use_own_knowledge_dataset.py \
|
||||
python examples/research_projects/rag/use_own_knowledge_dataset.py \
|
||||
--csv_path path/to/my_csv \
|
||||
--output_dir path/to/my_knowledge_dataset \
|
||||
```
|
||||
|
||||
The created outputs in `path/to/my_knowledge_dataset` can then be used to finetune RAG as follows:
|
||||
```bash
|
||||
python examples/rag/finetune_rag.py \
|
||||
python examples/research_projects/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
|
||||
@@ -9,6 +9,7 @@ from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
require_ray,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
)
|
||||
@@ -29,7 +30,7 @@ class RagFinetuneExampleTests(TestCasePlus):
|
||||
with open(os.path.join(data_dir, f"{split}.{field}"), "w") as f:
|
||||
f.write(content)
|
||||
|
||||
def _run_finetune(self, gpus: int):
|
||||
def _run_finetune(self, gpus: int, distributed_retriever: str = "pytorch"):
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
@@ -66,6 +67,7 @@ class RagFinetuneExampleTests(TestCasePlus):
|
||||
--gradient_accumulation_steps 1 \
|
||||
--distributed-port 8787 \
|
||||
--use_dummy_dataset 1 \
|
||||
--distributed_retriever {distributed_retriever} \
|
||||
""".split()
|
||||
|
||||
if gpus > 0:
|
||||
@@ -94,3 +96,15 @@ class RagFinetuneExampleTests(TestCasePlus):
|
||||
def test_finetune_multigpu(self):
|
||||
result = self._run_finetune(gpus=2)
|
||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_ray
|
||||
def test_finetune_gpu_ray_retrieval(self):
|
||||
result = self._run_finetune(gpus=1, distributed_retriever="ray")
|
||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@require_ray
|
||||
def test_finetune_multigpu_ray_retrieval(self):
|
||||
result = self._run_finetune(gpus=1, distributed_retriever="ray")
|
||||
self.assertGreaterEqual(result["test"][0]["test_avg_em"], 0.2)
|
||||
|
||||
@@ -31,14 +31,13 @@ class RagPyTorchDistributedRetriever(RagRetriever):
|
||||
If specified, use this index instead of the one built using the configuration
|
||||
"""
|
||||
|
||||
_init_retrieval = False
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None):
|
||||
super().__init__(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
init_retrieval=False,
|
||||
)
|
||||
self.process_group = None
|
||||
|
||||
154
examples/research_projects/rag/distributed_ray_retriever.py
Normal file
154
examples/research_projects/rag/distributed_ray_retriever.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import ray
|
||||
from transformers import RagConfig, RagRetriever, RagTokenizer
|
||||
from transformers.file_utils import requires_datasets, requires_faiss
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayRetriever:
|
||||
def __init__(self):
|
||||
self.initialized = False
|
||||
|
||||
def create_rag_retriever(self, config, question_encoder_tokenizer, generator_tokenizer, index):
|
||||
if not self.initialized:
|
||||
self.retriever = RagRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
init_retrieval=False,
|
||||
)
|
||||
self.initialized = True
|
||||
|
||||
def init_retrieval(self):
|
||||
self.retriever.index.init_index()
|
||||
|
||||
def retrieve(self, question_hidden_states, n_docs):
|
||||
doc_ids, retrieved_doc_embeds = self.retriever._main_retrieve(question_hidden_states, n_docs)
|
||||
return doc_ids, retrieved_doc_embeds
|
||||
|
||||
|
||||
class RagRayDistributedRetriever(RagRetriever):
|
||||
"""
|
||||
A distributed retriever built on top of the ``Ray`` API, a library
|
||||
for building distributed applications (https://docs.ray.io/en/master/).
|
||||
package. During training, all training workers initialize their own
|
||||
instance of a `RagRayDistributedRetriever`, and each instance of
|
||||
this distributed retriever shares a common set of Retrieval Ray
|
||||
Actors (https://docs.ray.io/en/master/walkthrough.html#remote
|
||||
-classes-actors) that load the index on separate processes. Ray
|
||||
handles the communication between the `RagRayDistributedRetriever`
|
||||
instances and the remote Ray actors. If training is done in a
|
||||
non-distributed setup, the index will simply be loaded in the same
|
||||
process as the training worker and Ray will not be used.
|
||||
|
||||
Args:
|
||||
config (:class:`~transformers.RagConfig`):
|
||||
The configuration of the RAG model this Retriever is used with. Contains parameters indicating which ``Index`` to build.
|
||||
question_encoder_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer that was used to tokenize the question.
|
||||
It is used to decode the question and then use the generator_tokenizer.
|
||||
generator_tokenizer (:class:`~transformers.PretrainedTokenizer`):
|
||||
The tokenizer used for the generator part of the RagModel.
|
||||
retrieval_workers (:obj:`List[ray.ActorClass(RayRetriever)]`): A list of already initialized `RayRetriever` actors.
|
||||
These actor classes run on remote processes and are responsible for performing the index lookup.
|
||||
index (:class:`~transformers.retrieval_rag.Index`, optional, defaults to the one defined by the configuration):
|
||||
If specified, use this index instead of the one built using the configuration
|
||||
"""
|
||||
|
||||
def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, retrieval_workers, index=None):
|
||||
if index is not None and index.is_initialized() and len(retrieval_workers) > 0:
|
||||
raise ValueError(
|
||||
"When using Ray for distributed fine-tuning, "
|
||||
"you'll need to provide the paths instead, "
|
||||
"as the dataset and the index are loaded "
|
||||
"separately. More info in examples/rag/use_own_knowledge_dataset.py "
|
||||
)
|
||||
super().__init__(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
index=index,
|
||||
init_retrieval=False,
|
||||
)
|
||||
self.retrieval_workers = retrieval_workers
|
||||
if len(self.retrieval_workers) > 0:
|
||||
ray.get(
|
||||
[
|
||||
worker.create_rag_retriever.remote(config, question_encoder_tokenizer, generator_tokenizer, index)
|
||||
for worker in self.retrieval_workers
|
||||
]
|
||||
)
|
||||
|
||||
def init_retrieval(self):
|
||||
"""
|
||||
Retriever initialization function, needs to be called from the
|
||||
training process. This function triggers retrieval initialization
|
||||
for all retrieval actors if using distributed setting, or loads
|
||||
index into current process if training is not distributed.
|
||||
"""
|
||||
logger.info("initializing retrieval")
|
||||
|
||||
if len(self.retrieval_workers) > 0:
|
||||
ray.get([worker.init_retrieval.remote() for worker in self.retrieval_workers])
|
||||
else:
|
||||
# Non-distributed training. Load index into this same process.
|
||||
self.index.init_index()
|
||||
|
||||
def retrieve(self, question_hidden_states, n_docs):
|
||||
"""
|
||||
Retrieves documents for specified ``question_hidden_states``. If
|
||||
running training with multiple workers, a random retrieval actor is
|
||||
selected to perform the index lookup and return the result.
|
||||
|
||||
Args:
|
||||
question_hidden_states (:obj:`np.ndarray` of shape :obj:`(batch_size, vector_size)`):
|
||||
A batch of query vectors to retrieve with.
|
||||
n_docs (:obj:`int`):
|
||||
The number of docs retrieved per query.
|
||||
|
||||
Output:
|
||||
retrieved_doc_embeds (:obj:`np.ndarray` of shape :obj:`(batch_size, n_docs, dim)`
|
||||
The retrieval embeddings of the retrieved docs per query.
|
||||
doc_ids (:obj:`np.ndarray` of shape :obj:`batch_size, n_docs`)
|
||||
The ids of the documents in the index
|
||||
doc_dicts (:obj:`List[dict]`):
|
||||
The retrieved_doc_embeds examples per query.
|
||||
"""
|
||||
if len(self.retrieval_workers) > 0:
|
||||
# Select a random retrieval actor.
|
||||
random_worker = self.retrieval_workers[random.randint(0, len(self.retrieval_workers) - 1)]
|
||||
doc_ids, retrieved_doc_embeds = ray.get(random_worker.retrieve.remote(question_hidden_states, n_docs))
|
||||
else:
|
||||
doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
|
||||
return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
|
||||
|
||||
@classmethod
|
||||
def get_tokenizers(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
|
||||
return super(RagRayDistributedRetriever, cls).get_tokenizers(retriever_name_or_path, indexed_dataset, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, retriever_name_or_path, actor_handles, indexed_dataset=None, **kwargs):
|
||||
requires_datasets(cls)
|
||||
requires_faiss(cls)
|
||||
config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
|
||||
rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
|
||||
question_encoder_tokenizer = rag_tokenizer.question_encoder
|
||||
generator_tokenizer = rag_tokenizer.generator
|
||||
if indexed_dataset is not None:
|
||||
config.index_name = "custom"
|
||||
index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
|
||||
else:
|
||||
index = cls._build_index(config)
|
||||
return cls(
|
||||
config,
|
||||
question_encoder_tokenizer=question_encoder_tokenizer,
|
||||
generator_tokenizer=generator_tokenizer,
|
||||
retrieval_workers=actor_handles,
|
||||
index=index,
|
||||
)
|
||||
@@ -130,8 +130,6 @@ def evaluate_batch_e2e(args, rag_model, questions):
|
||||
early_stopping=False,
|
||||
num_return_sequences=1,
|
||||
bad_words_ids=[[0, 0]], # BART likes to repeat BOS tokens, dont allow it to generate more than one
|
||||
clean_up_tokenization=True,
|
||||
print_docs=args.print_docs,
|
||||
)
|
||||
answers = rag_model.retriever.generator_tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
|
||||
@@ -29,6 +29,12 @@ from transformers import (
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from transformers import logging as transformers_logging
|
||||
from transformers.integrations import is_ray_available
|
||||
|
||||
|
||||
if is_ray_available():
|
||||
import ray
|
||||
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever
|
||||
|
||||
|
||||
from callbacks_rag import ( # noqa: E402 # isort:skipq
|
||||
@@ -36,7 +42,8 @@ from callbacks_rag import ( # noqa: E402 # isort:skipq
|
||||
get_early_stopping_callback,
|
||||
Seq2SeqLoggingCallback,
|
||||
)
|
||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
|
||||
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
from utils_rag import ( # noqa: E402 # isort:skip
|
||||
calculate_exact_match,
|
||||
flatten_list,
|
||||
@@ -88,7 +95,12 @@ class CustomAccel(DDPAccelerator):
|
||||
os.environ["MASTER_PORT"] = str(self.distributed_port)
|
||||
super().init_ddp_connection(global_rank, world_size, is_slurm_managing_tasks)
|
||||
if module.is_rag_model:
|
||||
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
||||
if module.distributed_retriever == "pytorch":
|
||||
module.model.rag.retriever.init_retrieval(self.distributed_port)
|
||||
elif module.distributed_retriever == "ray" and global_rank == 0:
|
||||
# For the Ray retriever, only initialize it once when global
|
||||
# rank is 0.
|
||||
module.model.rag.retriever.init_retrieval()
|
||||
|
||||
|
||||
class GenerativeQAModule(BaseTransformer):
|
||||
@@ -127,7 +139,13 @@ class GenerativeQAModule(BaseTransformer):
|
||||
config.generator.prefix = hparams.prefix
|
||||
config.label_smoothing = hparams.label_smoothing
|
||||
hparams, config.generator = set_extra_model_params(extra_model_params, hparams, config.generator)
|
||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
if hparams.distributed_retriever == "pytorch":
|
||||
retriever = RagPyTorchDistributedRetriever.from_pretrained(hparams.model_name_or_path, config=config)
|
||||
elif hparams.distributed_retriever == "ray":
|
||||
# The Ray retriever needs the handles to the retriever actors.
|
||||
retriever = RagRayDistributedRetriever.from_pretrained(
|
||||
hparams.model_name_or_path, hparams.actor_handles, config=config
|
||||
)
|
||||
model = self.model_class.from_pretrained(hparams.model_name_or_path, config=config, retriever=retriever)
|
||||
prefix = config.question_encoder.prefix
|
||||
else:
|
||||
@@ -180,7 +198,12 @@ class GenerativeQAModule(BaseTransformer):
|
||||
# For single GPU training, init_ddp_connection is not called.
|
||||
# So we need to initialize the retrievers here.
|
||||
if hparams.gpus <= 1:
|
||||
self.model.retriever.init_retrieval(self.distributed_port)
|
||||
if hparams.distributed_retriever == "ray":
|
||||
self.model.retriever.init_retrieval()
|
||||
elif hparams.distributed_retriever == "pytorch":
|
||||
self.model.retriever.init_retrieval(self.distributed_port)
|
||||
|
||||
self.distributed_retriever = hparams.distributed_retriever
|
||||
|
||||
def forward(self, input_ids, **kwargs):
|
||||
return self.model(input_ids, **kwargs)
|
||||
@@ -442,6 +465,20 @@ class GenerativeQAModule(BaseTransformer):
|
||||
default=None,
|
||||
help="Path to the faiss index for custom index. More info about custom indexes in the RagRetriever documentation as well as in `examples/rag/use_own_knowledge_dataset.py`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distributed_retriever",
|
||||
choices=["ray", "pytorch"],
|
||||
type=str,
|
||||
default="pytorch",
|
||||
help="What implementation to use for distributed retriever? If "
|
||||
"pytorch is selected, the index is loaded on training "
|
||||
"worker 0, and torch.distributed is used to handle "
|
||||
"communication between training worker 0, and the other "
|
||||
"training workers. If ray is selected, the Ray library is "
|
||||
"used to create load the index on separate processes, "
|
||||
"and Ray handles the communication between the training "
|
||||
"workers and the retrieval actors.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dummy_dataset",
|
||||
type=bool,
|
||||
@@ -450,9 +487,30 @@ class GenerativeQAModule(BaseTransformer):
|
||||
)
|
||||
return parser
|
||||
|
||||
@staticmethod
|
||||
def add_ray_specific_args(parser):
|
||||
# Ray cluster address.
|
||||
parser.add_argument(
|
||||
"--ray-address",
|
||||
default="auto",
|
||||
type=str,
|
||||
help="The address of the Ray cluster to connect to. If not "
|
||||
"specified, Ray will attempt to automatically detect the "
|
||||
"cluster. Has no effect if pytorch is used as the distributed "
|
||||
"retriever.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_retrieval_workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of retrieval actors to use when Ray is selected"
|
||||
"for the distributed retriever. Has no effect when "
|
||||
"distributed_retriever is set to pytorch.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(args=None, model=None) -> GenerativeQAModule:
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
@@ -461,6 +519,46 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
args = args or parser.parse_args()
|
||||
|
||||
Path(args.output_dir).mkdir(exist_ok=True)
|
||||
|
||||
named_actors = []
|
||||
if args.distributed_retriever == "ray" and args.gpus > 1:
|
||||
if not is_ray_available():
|
||||
raise RuntimeError("Please install Ray to use the Ray " "distributed retriever.")
|
||||
# Connect to an existing Ray cluster.
|
||||
try:
|
||||
ray.init(address=args.ray_address)
|
||||
except (ConnectionError, ValueError):
|
||||
logger.warning(
|
||||
"Connection to Ray cluster failed. Make sure a Ray"
|
||||
"cluster is running by either using Ray's cluster "
|
||||
"launcher (`ray up`) or by manually starting Ray on "
|
||||
"each node via `ray start --head` for the head node "
|
||||
"and `ray start --address='<ip address>:6379'` for "
|
||||
"additional nodes. See "
|
||||
"https://docs.ray.io/en/master/cluster/index.html "
|
||||
"for more info."
|
||||
)
|
||||
raise
|
||||
|
||||
# Create Ray actors only for rank 0.
|
||||
if ("LOCAL_RANK" not in os.environ or os.environ["LOCAL_RANK"] == 0) and (
|
||||
"NODE_RANK" not in os.environ or os.environ["NODE_RANK"] == 0
|
||||
):
|
||||
remote_cls = ray.remote(RayRetriever)
|
||||
named_actors = [
|
||||
remote_cls.options(name="retrieval_worker_{}".format(i)).remote()
|
||||
for i in range(args.num_retrieval_workers)
|
||||
]
|
||||
else:
|
||||
logger.info(
|
||||
"Getting named actors for NODE_RANK {}, LOCAL_RANK {}".format(
|
||||
os.environ["NODE_RANK"], os.environ["LOCAL_RANK"]
|
||||
)
|
||||
)
|
||||
named_actors = [ray.get_actor("retrieval_worker_{}".format(i)) for i in range(args.num_retrieval_workers)]
|
||||
args.actor_handles = named_actors
|
||||
assert args.actor_handles == named_actors
|
||||
|
||||
if model is None:
|
||||
model: GenerativeQAModule = GenerativeQAModule(args)
|
||||
|
||||
@@ -471,17 +569,17 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
or str(args.output_dir).startswith("/tmp")
|
||||
or str(args.output_dir).startswith("/var")
|
||||
):
|
||||
logger = True # don't pollute wandb logs unnecessarily
|
||||
training_logger = True # don't pollute wandb logs unnecessarily
|
||||
elif args.logger_name == "wandb":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
project = os.environ.get("WANDB_PROJECT", dataset)
|
||||
logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
training_logger = WandbLogger(name=model.output_dir.name, project=project)
|
||||
|
||||
elif args.logger_name == "wandb_shared":
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
training_logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")
|
||||
|
||||
es_callback = (
|
||||
get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
|
||||
@@ -495,8 +593,9 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
logging_callback=Seq2SeqLoggingCallback(),
|
||||
checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric),
|
||||
early_stopping_callback=es_callback,
|
||||
logger=logger,
|
||||
logger=training_logger,
|
||||
accelerator=CustomAccel() if args.gpus > 1 else None,
|
||||
profiler=pl.profiler.AdvancedProfiler() if args.profile else None,
|
||||
)
|
||||
pickle_save(model.hparams, model.output_dir / "hparams.pkl")
|
||||
|
||||
@@ -509,4 +608,19 @@ def main(args=None, model=None) -> GenerativeQAModule:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser = GenerativeQAModule.add_model_specific_args(parser, os.getcwd())
|
||||
parser = GenerativeQAModule.add_retriever_specific_args(parser)
|
||||
parser = GenerativeQAModule.add_ray_specific_args(parser)
|
||||
|
||||
# Pytorch Lightning Profiler
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="If True, use pytorch_lightning.profiler.AdvancedProfiler to profile the Trainer.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./examples/rag/finetune.sh --help to see all the possible options
|
||||
# run ./examples/rag/finetune_rag.sh --help to see all the possible options
|
||||
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
@@ -11,10 +11,10 @@ python examples/rag/finetune_rag.py \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8 \
|
||||
--profile \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--val_check_interval 0.25 \
|
||||
--train_batch_size 8 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 128 \
|
||||
@@ -31,4 +31,4 @@ python examples/rag/finetune_rag.py \
|
||||
--learning_rate 3e-05 \
|
||||
--num_train_epochs 100 \
|
||||
--warmup_steps 500 \
|
||||
--gradient_accumulation_steps 1
|
||||
--gradient_accumulation_steps 1 \
|
||||
|
||||
44
examples/research_projects/rag/finetune_rag_ray.sh
Executable file
44
examples/research_projects/rag/finetune_rag_ray.sh
Executable file
@@ -0,0 +1,44 @@
|
||||
# Sample script to finetune RAG using Ray for distributed retrieval.
|
||||
|
||||
# Add parent directory to python path to access lightning_base.py
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
|
||||
# Start a single-node Ray cluster.
|
||||
ray start --head
|
||||
|
||||
# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
|
||||
# run ./examples/rag/finetune_rag_ray.sh --help to see all the possible options
|
||||
|
||||
python examples/rag/finetune_rag.py \
|
||||
--data_dir $DATA_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--model_name_or_path $MODEL_NAME_OR_PATH \
|
||||
--model_type rag_sequence \
|
||||
--fp16 \
|
||||
--gpus 8 \
|
||||
--profile \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--n_val -1 \
|
||||
--train_batch_size 8 \
|
||||
--eval_batch_size 1 \
|
||||
--max_source_length 128 \
|
||||
--max_target_length 25 \
|
||||
--val_max_target_length 25 \
|
||||
--test_max_target_length 25 \
|
||||
--label_smoothing 0.1 \
|
||||
--dropout 0.1 \
|
||||
--attention_dropout 0.1 \
|
||||
--weight_decay 0.001 \
|
||||
--adam_epsilon 1e-08 \
|
||||
--max_grad_norm 0.1 \
|
||||
--lr_scheduler polynomial \
|
||||
--learning_rate 3e-05 \
|
||||
--num_train_epochs 100 \
|
||||
--warmup_steps 500 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--distributed_retriever ray \
|
||||
--num_retrieval_workers 4
|
||||
|
||||
# Stop the Ray cluster.
|
||||
ray stop
|
||||
@@ -13,15 +13,27 @@ from datasets import Dataset
|
||||
import faiss
|
||||
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||
from transformers.integrations import is_ray_available
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import require_torch_non_multi_gpu_but_fix_me
|
||||
from transformers.testing_utils import require_ray, require_torch_non_multi_gpu_but_fix_me
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||
|
||||
from distributed_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
if is_torch_available():
|
||||
from distributed_pytorch_retriever import RagPyTorchDistributedRetriever # noqa: E402 # isort:skip
|
||||
else:
|
||||
RagPyTorchDistributedRetriever = None
|
||||
|
||||
if is_ray_available():
|
||||
import ray # noqa: E402 # isort:skip
|
||||
from distributed_ray_retriever import RagRayDistributedRetriever, RayRetriever # noqa: E402 # isort:skip
|
||||
else:
|
||||
ray = None
|
||||
RagRayDistributedRetriever = None
|
||||
RayRetriever = None
|
||||
|
||||
|
||||
def require_distributed_retrieval(test_case):
|
||||
@@ -32,8 +44,8 @@ def require_distributed_retrieval(test_case):
|
||||
These tests are skipped when respective libraries are not installed.
|
||||
|
||||
"""
|
||||
if not (is_torch_available() and is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||
test_case = unittest.skip("test requires PyTorch, Datasets, Faiss, psutil")(test_case)
|
||||
if not (is_datasets_available() and is_faiss_available() and is_psutil_available()):
|
||||
test_case = unittest.skip("test requires Datasets, Faiss, psutil")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
@@ -144,7 +156,31 @@ class RagRetrieverTest(TestCase):
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
def get_dummy_custom_hf_index_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
||||
def get_dummy_ray_distributed_retriever(self, init_retrieval: bool) -> RagRayDistributedRetriever:
|
||||
# Have to run in local mode because sys.path modifications at top of
|
||||
# file are not propogated to remote workers.
|
||||
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
|
||||
ray.init(local_mode=True)
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
)
|
||||
remote_cls = ray.remote(RayRetriever)
|
||||
workers = [remote_cls.remote() for _ in range(1)]
|
||||
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = self.get_dummy_dataset()
|
||||
retriever = RagRayDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
retrieval_workers=workers,
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval()
|
||||
return retriever
|
||||
|
||||
def get_dummy_custom_hf_index_pytorch_retriever(self, init_retrieval: bool, from_disk: bool, port=12345):
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
@@ -175,13 +211,51 @@ class RagRetrieverTest(TestCase):
|
||||
retriever.init_retrieval(port)
|
||||
return retriever
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_pytorch_distributed_retriever(init_retrieval=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
def get_dummy_custom_hf_index_ray_retriever(self, init_retrieval: bool, from_disk: bool):
|
||||
# Have to run in local mode because sys.path modifications at top of
|
||||
# file are not propogated to remote workers.
|
||||
# https://stackoverflow.com/questions/54338013/parallel-import-a-python-file-from-sibling-folder
|
||||
ray.init(local_mode=True)
|
||||
dataset = self.get_dummy_dataset()
|
||||
config = RagConfig(
|
||||
retrieval_vector_size=self.retrieval_vector_size,
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
index_name="custom",
|
||||
)
|
||||
remote_cls = ray.remote(RayRetriever)
|
||||
workers = [remote_cls.remote() for _ in range(1)]
|
||||
if from_disk:
|
||||
config.passages_path = os.path.join(self.tmpdirname, "dataset")
|
||||
config.index_path = os.path.join(self.tmpdirname, "index.faiss")
|
||||
dataset.get_index("embeddings").save(os.path.join(self.tmpdirname, "index.faiss"))
|
||||
dataset.drop_index("embeddings")
|
||||
dataset.save_to_disk(os.path.join(self.tmpdirname, "dataset"))
|
||||
del dataset
|
||||
retriever = RagRayDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
retrieval_workers=workers,
|
||||
index=CustomHFIndex.load_from_disk(
|
||||
vector_size=config.retrieval_vector_size,
|
||||
dataset_path=config.passages_path,
|
||||
index_path=config.index_path,
|
||||
),
|
||||
)
|
||||
else:
|
||||
retriever = RagRayDistributedRetriever(
|
||||
config,
|
||||
question_encoder_tokenizer=self.get_dpr_tokenizer(),
|
||||
generator_tokenizer=self.get_bart_tokenizer(),
|
||||
retrieval_workers=workers,
|
||||
index=CustomHFIndex(config.retrieval_vector_size, dataset),
|
||||
)
|
||||
if init_retrieval:
|
||||
retriever.init_retrieval()
|
||||
return retriever
|
||||
|
||||
def distributed_retriever_check(self, retriever: RagRetriever, hidden_states: np.array, n_docs: int) -> None:
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
@@ -192,33 +266,76 @@ class RagRetrieverTest(TestCase):
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_hf_index_retriever_retrieve(self):
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=False)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_pytorch_distributed_retriever(init_retrieval=True), hidden_states, n_docs
|
||||
)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_hf_index_pytorch_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=False),
|
||||
hidden_states,
|
||||
n_docs,
|
||||
)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(init_retrieval=True, from_disk=True)
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
retrieved_doc_embeds, doc_ids, doc_dicts = retriever.retrieve(hidden_states, n_docs=n_docs)
|
||||
self.assertEqual(retrieved_doc_embeds.shape, (2, n_docs, self.retrieval_vector_size))
|
||||
self.assertEqual(len(doc_dicts), 2)
|
||||
self.assertEqual(sorted(doc_dicts[0]), ["embeddings", "id", "text", "title"])
|
||||
self.assertEqual(len(doc_dicts[0]["id"]), n_docs)
|
||||
self.assertEqual(doc_dicts[0]["id"][0], "1") # max inner product is reached with second doc
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_pytorch_retriever(init_retrieval=True, from_disk=True),
|
||||
hidden_states,
|
||||
n_docs,
|
||||
)
|
||||
|
||||
@require_ray
|
||||
def test_ray_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_ray_distributed_retriever(init_retrieval=True), hidden_states, n_docs
|
||||
)
|
||||
ray.shutdown()
|
||||
|
||||
@require_ray
|
||||
def test_custom_hf_index_ray_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=False),
|
||||
hidden_states,
|
||||
n_docs,
|
||||
)
|
||||
ray.shutdown()
|
||||
|
||||
@require_ray
|
||||
def test_custom_ray_distributed_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
[np.ones(self.retrieval_vector_size), -np.ones(self.retrieval_vector_size)], dtype=np.float32
|
||||
)
|
||||
|
||||
self.distributed_retriever_check(
|
||||
self.get_dummy_custom_hf_index_ray_retriever(init_retrieval=True, from_disk=True), hidden_states, n_docs
|
||||
)
|
||||
ray.shutdown()
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
#!/usr/bin/env bash
|
||||
export PYTHONPATH="../":"${PYTHONPATH}"
|
||||
export WANDB_PROJECT=dmar
|
||||
python distillation.py \
|
||||
export MAX_LEN=128
|
||||
python finetune.py \
|
||||
--learning_rate=3e-4 \
|
||||
--do_train \
|
||||
--do_predict \
|
||||
--fp16 --no_teacher \
|
||||
--fp16 \
|
||||
--val_check_interval 0.25 \
|
||||
--data_dir $ENRO_DIR \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
|
||||
@@ -1,645 +0,0 @@
|
||||
import itertools
|
||||
import json
|
||||
import linecache
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
||||
|
||||
import git
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from rouge_score import rouge_scorer, scoring
|
||||
from sacrebleu import corpus_bleu
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
|
||||
|
||||
try:
|
||||
from fairseq.data.data_utils import batch_by_size
|
||||
|
||||
FAIRSEQ_AVAILABLE = True
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
FAIRSEQ_AVAILABLE = False
|
||||
|
||||
|
||||
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
||||
"""From fairseq"""
|
||||
if target.dim() == lprobs.dim() - 1:
|
||||
target = target.unsqueeze(-1)
|
||||
nll_loss = -lprobs.gather(dim=-1, index=target)
|
||||
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
|
||||
if ignore_index is not None:
|
||||
pad_mask = target.eq(ignore_index)
|
||||
nll_loss.masked_fill_(pad_mask, 0.0)
|
||||
smooth_loss.masked_fill_(pad_mask, 0.0)
|
||||
else:
|
||||
nll_loss = nll_loss.squeeze(-1)
|
||||
smooth_loss = smooth_loss.squeeze(-1)
|
||||
|
||||
nll_loss = nll_loss.sum() # mean()? Scared to break other math.
|
||||
smooth_loss = smooth_loss.sum()
|
||||
eps_i = epsilon / lprobs.size(-1)
|
||||
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
|
||||
return loss, nll_loss
|
||||
|
||||
|
||||
def lmap(f: Callable, x: Iterable) -> List:
|
||||
"""list(map(f, x))"""
|
||||
return list(map(f, x))
|
||||
|
||||
|
||||
def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
|
||||
"""Uses sacrebleu's corpus_bleu implementation."""
|
||||
return {"bleu": round(corpus_bleu(output_lns, [refs_lns], **kwargs).score, 4)}
|
||||
|
||||
|
||||
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
|
||||
def non_pad_len(tokens: np.ndarray) -> int:
|
||||
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
||||
|
||||
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
||||
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
|
||||
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
||||
pred_str = lmap(str.strip, pred_str)
|
||||
label_str = lmap(str.strip, label_str)
|
||||
return pred_str, label_str
|
||||
|
||||
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
rouge: Dict = calculate_rouge(pred_str, label_str)
|
||||
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
rouge.update({"gen_len": summ_len})
|
||||
return rouge
|
||||
|
||||
def translation_metrics(pred: EvalPrediction) -> Dict:
|
||||
pred_str, label_str = decode_pred(pred)
|
||||
bleu: Dict = calculate_bleu(pred_str, label_str)
|
||||
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
||||
bleu.update({"gen_len": gen_len})
|
||||
return bleu
|
||||
|
||||
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
||||
return compute_metrics_fn
|
||||
|
||||
|
||||
def trim_batch(
|
||||
input_ids,
|
||||
pad_token_id,
|
||||
attention_mask=None,
|
||||
):
|
||||
"""Remove columns that are populated exclusively by pad_token_id"""
|
||||
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
|
||||
if attention_mask is None:
|
||||
return input_ids[:, keep_column_mask]
|
||||
else:
|
||||
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
||||
|
||||
|
||||
class AbstractSeq2SeqDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
data_dir,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
type_path="train",
|
||||
n_obs=None,
|
||||
prefix="",
|
||||
**dataset_kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
|
||||
self.len_file = Path(data_dir).joinpath(type_path + ".len")
|
||||
if os.path.exists(self.len_file):
|
||||
self.src_lens = pickle_load(self.len_file)
|
||||
self.used_char_len = False
|
||||
else:
|
||||
self.src_lens = self.get_char_lens(self.src_file)
|
||||
self.used_char_len = True
|
||||
self.max_source_length = max_source_length
|
||||
self.max_target_length = max_target_length
|
||||
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
||||
self.tokenizer = tokenizer
|
||||
self.prefix = prefix if prefix is not None else ""
|
||||
|
||||
if n_obs is not None:
|
||||
self.src_lens = self.src_lens[:n_obs]
|
||||
self.pad_token_id = self.tokenizer.pad_token_id
|
||||
self.dataset_kwargs = dataset_kwargs
|
||||
dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})
|
||||
|
||||
def __len__(self):
|
||||
return len(self.src_lens)
|
||||
|
||||
@staticmethod
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
@cached_property
|
||||
def tgt_lens(self):
|
||||
"""Length in characters of target documents"""
|
||||
return self.get_char_lens(self.tgt_file)
|
||||
|
||||
def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
|
||||
if distributed:
|
||||
return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
|
||||
else:
|
||||
return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
|
||||
|
||||
def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
|
||||
assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
|
||||
assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
|
||||
sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))
|
||||
|
||||
def num_tokens_in_example(i):
|
||||
return min(self.src_lens[i], self.max_target_length)
|
||||
|
||||
# call fairseq cython function
|
||||
batch_sampler: List[List[int]] = batch_by_size(
|
||||
sorted_indices,
|
||||
num_tokens_fn=num_tokens_in_example,
|
||||
max_tokens=max_tokens_per_batch,
|
||||
required_batch_size_multiple=64,
|
||||
)
|
||||
shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
|
||||
# move the largest batch to the front to OOM quickly (uses an approximation for padding)
|
||||
approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
|
||||
largest_batch_idx = np.argmax(approximate_toks_per_batch)
|
||||
shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
|
||||
shuffled_batches[largest_batch_idx],
|
||||
shuffled_batches[0],
|
||||
)
|
||||
return shuffled_batches
|
||||
|
||||
def __getitem__(self, item):
|
||||
raise NotImplementedError("You must implement this")
|
||||
|
||||
def collate_fn(self, batch):
|
||||
raise NotImplementedError("You must implement this")
|
||||
|
||||
|
||||
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
||||
"""Call tokenizer on src and tgt_lines"""
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
assert source_line, f"empty source line for index {index}"
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
|
||||
target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)
|
||||
|
||||
source_ids = source_inputs["input_ids"].squeeze()
|
||||
target_ids = target_inputs["input_ids"].squeeze()
|
||||
src_mask = source_inputs["attention_mask"].squeeze()
|
||||
return {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": src_mask,
|
||||
"labels": target_ids,
|
||||
}
|
||||
|
||||
def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
|
||||
"""Only used by LegacyDataset"""
|
||||
return tokenizer(
|
||||
[line],
|
||||
max_length=max_length,
|
||||
padding="max_length" if pad_to_max_length else None,
|
||||
truncation=True,
|
||||
return_tensors=return_tensors,
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
target_ids = torch.stack([x["labels"] for x in batch])
|
||||
pad_token_id = self.pad_token_id
|
||||
y = trim_batch(target_ids, pad_token_id)
|
||||
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
||||
batch = {
|
||||
"input_ids": source_ids,
|
||||
"attention_mask": source_mask,
|
||||
"labels": y,
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
||||
"""A dataset that calls prepare_seq2seq_batch."""
|
||||
|
||||
def __getitem__(self, index) -> Dict[str, str]:
|
||||
index = index + 1 # linecache starts at 1
|
||||
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
||||
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
||||
assert source_line, f"empty source line for index {index}"
|
||||
assert tgt_line, f"empty tgt line for index {index}"
|
||||
return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
"""Call prepare_seq2seq_batch."""
|
||||
batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
max_length=self.max_source_length,
|
||||
max_target_length=self.max_target_length,
|
||||
return_tensors="pt",
|
||||
**self.dataset_kwargs,
|
||||
).data
|
||||
batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
|
||||
return batch_encoding
|
||||
|
||||
|
||||
class Seq2SeqDataCollator:
|
||||
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
assert (
|
||||
self.pad_token_id is not None
|
||||
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
|
||||
self.data_args = data_args
|
||||
self.tpu_num_cores = tpu_num_cores
|
||||
self.dataset_kwargs = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
|
||||
if data_args.src_lang is not None:
|
||||
self.dataset_kwargs["src_lang"] = data_args.src_lang
|
||||
if data_args.tgt_lang is not None:
|
||||
self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
|
||||
|
||||
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
||||
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
||||
batch = self._encode(batch)
|
||||
input_ids, attention_mask, labels = (
|
||||
batch["input_ids"],
|
||||
batch["attention_mask"],
|
||||
batch["labels"],
|
||||
)
|
||||
else:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
attention_mask = torch.stack([x["attention_mask"] for x in batch])
|
||||
labels = torch.stack([x["labels"] for x in batch])
|
||||
|
||||
labels = trim_batch(labels, self.pad_token_id)
|
||||
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
|
||||
|
||||
if isinstance(self.tokenizer, T5Tokenizer):
|
||||
decoder_input_ids = self._shift_right_t5(labels)
|
||||
else:
|
||||
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
return batch
|
||||
|
||||
def _shift_right_t5(self, input_ids):
|
||||
# shift inputs to the right
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = self.pad_token_id
|
||||
return shifted_input_ids
|
||||
|
||||
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
||||
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
||||
[x["src_texts"] for x in batch],
|
||||
tgt_texts=[x["tgt_texts"] for x in batch],
|
||||
max_length=self.data_args.max_source_length,
|
||||
max_target_length=self.data_args.max_target_length,
|
||||
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
||||
return_tensors="pt",
|
||||
**self.dataset_kwargs,
|
||||
)
|
||||
return batch_encoding.data
|
||||
|
||||
|
||||
class SortishSampler(Sampler):
|
||||
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
||||
|
||||
def __init__(self, data, batch_size, shuffle=True):
|
||||
self.data, self.bs, self.shuffle = data, batch_size, shuffle
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
|
||||
|
||||
|
||||
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
|
||||
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
||||
if not shuffle:
|
||||
return np.argsort(np.array(data) * -1)
|
||||
|
||||
def key_fn(i):
|
||||
return data[i]
|
||||
|
||||
idxs = np.random.permutation(len(data))
|
||||
sz = bs * 50
|
||||
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
|
||||
sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
|
||||
sz = bs
|
||||
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
|
||||
max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
|
||||
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
|
||||
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int)
|
||||
sort_idx = np.concatenate((ck_idx[0], sort_idx))
|
||||
return sort_idx
|
||||
|
||||
|
||||
class DistributedSortishSampler(Sampler):
|
||||
"""Copied from torch DistributedSampler"""
|
||||
|
||||
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = dist.get_rank()
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
if add_extra_examples:
|
||||
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
else:
|
||||
self.total_size = len(dataset)
|
||||
self.num_samples = len(self.available_indices)
|
||||
self.batch_size = batch_size
|
||||
self.add_extra_examples = add_extra_examples
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self) -> Iterable:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
|
||||
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
|
||||
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
|
||||
indices = [self.available_indices[i] for i in sortish_indices]
|
||||
assert len(indices) == self.num_samples
|
||||
return iter(indices)
|
||||
|
||||
@cached_property
|
||||
def available_indices(self) -> np.array:
|
||||
indices = list(range(len(self.dataset)))
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[: (self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
# subsample
|
||||
available_indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||
return available_indices
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def use_task_specific_params(model, task):
|
||||
"""Update config with summarization specific params."""
|
||||
task_specific_params = model.config.task_specific_params
|
||||
|
||||
if task_specific_params is not None:
|
||||
pars = task_specific_params.get(task, {})
|
||||
logger.info(f"using task specific params for {task}: {pars}")
|
||||
model.config.update(pars)
|
||||
|
||||
|
||||
def pickle_load(path):
|
||||
"""pickle.load(path)"""
|
||||
with open(path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
|
||||
def pickle_save(obj, path):
|
||||
"""pickle.dump(obj, path)"""
|
||||
with open(path, "wb") as f:
|
||||
return pickle.dump(obj, f)
|
||||
|
||||
|
||||
def flatten_list(summary_ids: List[List]):
|
||||
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
||||
|
||||
|
||||
def save_git_info(folder_path: str) -> None:
|
||||
"""Save git information to output_dir/git_log.json"""
|
||||
repo_infos = get_git_info()
|
||||
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
|
||||
|
||||
|
||||
def save_json(content, path, indent=4, **json_dump_kwargs):
|
||||
with open(path, "w") as f:
|
||||
json.dump(content, f, indent=indent, **json_dump_kwargs)
|
||||
|
||||
|
||||
def load_json(path):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def get_git_info():
|
||||
try:
|
||||
repo = git.Repo(search_parent_directories=True)
|
||||
repo_infos = {
|
||||
"repo_id": str(repo),
|
||||
"repo_sha": str(repo.head.object.hexsha),
|
||||
"repo_branch": str(repo.active_branch),
|
||||
"hostname": str(socket.gethostname()),
|
||||
}
|
||||
return repo_infos
|
||||
except TypeError:
|
||||
return {
|
||||
"repo_id": None,
|
||||
"repo_sha": None,
|
||||
"repo_branch": None,
|
||||
"hostname": None,
|
||||
}
|
||||
|
||||
|
||||
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
|
||||
|
||||
|
||||
def extract_rouge_mid_statistics(dct):
|
||||
new_dict = {}
|
||||
for k1, v1 in dct.items():
|
||||
mid = v1.mid
|
||||
new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
|
||||
return new_dict
|
||||
|
||||
|
||||
def calculate_rouge(
|
||||
pred_lns: List[str],
|
||||
tgt_lns: List[str],
|
||||
use_stemmer=True,
|
||||
rouge_keys=ROUGE_KEYS,
|
||||
return_precision_and_recall=False,
|
||||
bootstrap_aggregation=True,
|
||||
newline_sep=True,
|
||||
) -> Dict:
|
||||
"""Calculate rouge using rouge_scorer package.
|
||||
|
||||
Args:
|
||||
pred_lns: list of summaries generated by model
|
||||
tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
|
||||
use_stemmer: Bool indicating whether Porter stemmer should be used to
|
||||
strip word suffixes to improve matching.
|
||||
rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
|
||||
return_precision_and_recall: (False) whether to also return precision and recall.
|
||||
bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
|
||||
this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]``
|
||||
newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL
|
||||
on multi sentence summaries (CNN/DM dataset).
|
||||
|
||||
Returns:
|
||||
Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys
|
||||
|
||||
"""
|
||||
scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
|
||||
aggregator = scoring.BootstrapAggregator()
|
||||
for pred, tgt in zip(tgt_lns, pred_lns):
|
||||
# rougeLsum expects "\n" separated sentences within a summary
|
||||
if newline_sep:
|
||||
pred = add_newline_to_end_of_each_sentence(pred)
|
||||
tgt = add_newline_to_end_of_each_sentence(tgt)
|
||||
scores = scorer.score(pred, tgt)
|
||||
aggregator.add_scores(scores)
|
||||
|
||||
if bootstrap_aggregation:
|
||||
result = aggregator.aggregate()
|
||||
if return_precision_and_recall:
|
||||
return extract_rouge_mid_statistics(result) # here we return dict
|
||||
else:
|
||||
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
|
||||
|
||||
else:
|
||||
return aggregator._scores # here we return defaultdict(list)
|
||||
|
||||
|
||||
# Utilities for freezing parameters and checking whether they are frozen
|
||||
|
||||
|
||||
def freeze_params(model: nn.Module):
|
||||
"""Set requires_grad=False for each of model.parameters()"""
|
||||
for par in model.parameters():
|
||||
par.requires_grad = False
|
||||
|
||||
|
||||
def freeze_embeds(model):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
model_type = model.config.model_type
|
||||
|
||||
if model_type == "t5":
|
||||
freeze_params(model.shared)
|
||||
for d in [model.encoder, model.decoder]:
|
||||
freeze_params(d.embed_tokens)
|
||||
elif model_type == "fsmt":
|
||||
for d in [model.model.encoder, model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
else:
|
||||
freeze_params(model.model.shared)
|
||||
for d in [model.model.encoder, model.model.decoder]:
|
||||
freeze_params(d.embed_positions)
|
||||
freeze_params(d.embed_tokens)
|
||||
|
||||
|
||||
def grad_status(model: nn.Module) -> Iterable:
|
||||
return (par.requires_grad for par in model.parameters())
|
||||
|
||||
|
||||
def any_requires_grad(model: nn.Module) -> bool:
|
||||
return any(grad_status(model))
|
||||
|
||||
|
||||
def assert_all_frozen(model):
|
||||
model_grads: List[bool] = list(grad_status(model))
|
||||
n_require_grad = sum(lmap(int, model_grads))
|
||||
npars = len(model_grads)
|
||||
assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
|
||||
|
||||
|
||||
def assert_not_all_frozen(model):
|
||||
model_grads: List[bool] = list(grad_status(model))
|
||||
npars = len(model_grads)
|
||||
assert any(model_grads), f"none of {npars} weights require grad"
|
||||
|
||||
|
||||
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
|
||||
"""
|
||||
Parse an argv list of unspecified command line args to a dict.
|
||||
Assumes all values are either numeric or boolean in the form of true/false.
|
||||
"""
|
||||
result = {}
|
||||
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
|
||||
num_pairs = len(unparsed_args) // 2
|
||||
for pair_num in range(num_pairs):
|
||||
i = 2 * pair_num
|
||||
assert unparsed_args[i].startswith("--")
|
||||
if unparsed_args[i + 1].lower() == "true":
|
||||
value = True
|
||||
elif unparsed_args[i + 1].lower() == "false":
|
||||
value = False
|
||||
else:
|
||||
try:
|
||||
value = int(unparsed_args[i + 1])
|
||||
except ValueError:
|
||||
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
|
||||
|
||||
result[unparsed_args[i][2:]] = value
|
||||
return result
|
||||
|
||||
|
||||
def write_txt_file(ordered_tgt, path):
|
||||
f = Path(path).open("w")
|
||||
for ln in ordered_tgt:
|
||||
f.write(ln + "\n")
|
||||
f.flush()
|
||||
|
||||
|
||||
def chunks(lst, n):
|
||||
"""Yield successive n-sized chunks from lst."""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
|
||||
def check_output_dir(args, expected_items=0):
|
||||
"""
|
||||
Checks whether to bail out if output_dir already exists and has more than expected_items in it
|
||||
|
||||
`args`: needs to have the following attributes of `args`:
|
||||
- output_dir
|
||||
- do_train
|
||||
- overwrite_output_dir
|
||||
|
||||
`expected_items`: normally 0 (default) - i.e. empty dir, but in some cases a few files are expected (e.g. recovery from OOM)
|
||||
"""
|
||||
if (
|
||||
os.path.exists(args.output_dir)
|
||||
and len(os.listdir(args.output_dir)) > expected_items
|
||||
and args.do_train
|
||||
and not args.overwrite_output_dir
|
||||
):
|
||||
raise ValueError(
|
||||
f"Output directory ({args.output_dir}) already exists and "
|
||||
f"has {len(os.listdir(args.output_dir))} items in it (expected {expected_items} items). "
|
||||
"Use --overwrite_output_dir to overcome."
|
||||
)
|
||||
@@ -18,7 +18,7 @@ limitations under the License.
|
||||
|
||||
This directory contains examples for finetuning and evaluating transformers on summarization and translation tasks.
|
||||
Please tag @patil-suraj with any issues/unexpected behaviors, or send a PR!
|
||||
For deprecated `bertabs` instructions, see [`bertabs/README.md`](bertabs/README.md).
|
||||
For deprecated `bertabs` instructions, see [`bertabs/README.md`](https://github.com/huggingface/transformers/blob/master/examples/research_projects/bertabs/README.md).
|
||||
|
||||
### Supported Architectures
|
||||
|
||||
|
||||
47
examples/seq2seq/ds_config.json
Normal file
47
examples/seq2seq/ds_config.json
Normal file
@@ -0,0 +1,47 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": true,
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 2e8,
|
||||
"overlap_comm": true,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 2e8,
|
||||
"contiguous_gradients": true,
|
||||
"cpu_offload": true
|
||||
},
|
||||
|
||||
"zero_allow_untested_optimizer": true,
|
||||
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 3e-5,
|
||||
"betas": [
|
||||
0.8,
|
||||
0.999
|
||||
],
|
||||
"eps": 1e-8,
|
||||
"weight_decay": 3e-7
|
||||
}
|
||||
},
|
||||
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": 0,
|
||||
"warmup_max_lr": 3e-5,
|
||||
"warmup_num_steps": 500
|
||||
}
|
||||
},
|
||||
|
||||
"steps_per_print": 2000,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -16,14 +16,20 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import transformers
|
||||
from seq2seq_trainer import Seq2SeqTrainer
|
||||
from seq2seq_training_args import Seq2SeqTrainingArguments
|
||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
MBartTokenizer,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.trainer_utils import EvaluationStrategy, is_main_process
|
||||
from transformers.training_args import ParallelMode
|
||||
from utils import (
|
||||
@@ -98,7 +104,9 @@ class DataTrainingArguments:
|
||||
default=142,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
"than this will be truncated, sequences shorter will be padded. "
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
},
|
||||
)
|
||||
test_max_target_length: Optional[int] = field(
|
||||
@@ -120,30 +128,6 @@ class DataTrainingArguments:
|
||||
)
|
||||
|
||||
|
||||
def speed_metrics(split, start_time, num_samples):
|
||||
"""
|
||||
Measure and return speed performance metrics.
|
||||
|
||||
This function requires a time snapshot `start_time` before the operation to be measured starts and this
|
||||
function should be run immediately after the operation to be measured has completed.
|
||||
|
||||
Args:
|
||||
- split: one of train, val, test
|
||||
- start_time: operation start time
|
||||
- num_samples: number of samples processed
|
||||
|
||||
"""
|
||||
runtime = time.time() - start_time
|
||||
result = {}
|
||||
|
||||
samples_per_second = 1 / (runtime / num_samples)
|
||||
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
|
||||
result[f"{split}_runtime"] = round(runtime, 4)
|
||||
|
||||
result[f"{split}_n_ojbs"] = num_samples
|
||||
return result
|
||||
|
||||
|
||||
def handle_metrics(split, metrics, output_dir):
|
||||
"""
|
||||
Log and save metrics
|
||||
@@ -155,8 +139,8 @@ def handle_metrics(split, metrics, output_dir):
|
||||
"""
|
||||
|
||||
logger.info(f"***** {split} metrics *****")
|
||||
for key, value in metrics.items():
|
||||
logger.info(f" {key} = {value}")
|
||||
for key in sorted(metrics.keys()):
|
||||
logger.info(f" {key} = {metrics[key]}")
|
||||
save_json(metrics, os.path.join(output_dir, f"{split}_results.json"))
|
||||
|
||||
|
||||
@@ -297,13 +281,12 @@ def main():
|
||||
)
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
config=config,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=Seq2SeqDataCollator(tokenizer, data_args, training_args.tpu_num_cores),
|
||||
compute_metrics=compute_metrics_fn,
|
||||
data_args=data_args,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
all_metrics = {}
|
||||
@@ -311,11 +294,11 @@ def main():
|
||||
if training_args.do_train:
|
||||
logger.info("*** Train ***")
|
||||
|
||||
start_time = time.time()
|
||||
trainer.train(
|
||||
train_result = trainer.train(
|
||||
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
|
||||
)
|
||||
metrics = speed_metrics("train", start_time, data_args.n_train)
|
||||
metrics = train_result.metrics
|
||||
metrics["train_n_objs"] = data_args.n_train
|
||||
|
||||
trainer.save_model() # this also saves the tokenizer
|
||||
|
||||
@@ -334,9 +317,10 @@ def main():
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
start_time = time.time()
|
||||
metrics = trainer.evaluate(metric_key_prefix="val")
|
||||
metrics.update(speed_metrics("val", start_time, data_args.n_val))
|
||||
metrics = trainer.evaluate(
|
||||
metric_key_prefix="val", max_length=data_args.val_max_target_length, num_beams=data_args.eval_beams
|
||||
)
|
||||
metrics["val_n_objs"] = data_args.n_val
|
||||
metrics["val_loss"] = round(metrics["val_loss"], 4)
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
@@ -347,10 +331,14 @@ def main():
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
start_time = time.time()
|
||||
test_output = trainer.predict(test_dataset=test_dataset, metric_key_prefix="test")
|
||||
test_output = trainer.predict(
|
||||
test_dataset=test_dataset,
|
||||
metric_key_prefix="test",
|
||||
max_length=data_args.val_max_target_length,
|
||||
num_beams=data_args.eval_beams,
|
||||
)
|
||||
metrics = test_output.metrics
|
||||
metrics.update(speed_metrics("test", start_time, data_args.n_test))
|
||||
metrics["test_n_objs"] = data_args.n_test
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
metrics["test_loss"] = round(metrics["test_loss"], 4)
|
||||
|
||||
@@ -23,7 +23,7 @@ from pack_dataset import pack_data_dir
|
||||
from parameterized import parameterized
|
||||
from save_len_file import save_len_file
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
from transformers.models.mbart.modeling_mbart import shift_tokens_right
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_non_multi_gpu_but_fix_me, slow
|
||||
from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
|
||||
|
||||
|
||||
@@ -14,10 +14,11 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers import BertTokenizer, EncoderDecoderModel
|
||||
from transformers.file_utils import is_datasets_available
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.integrations import is_deepspeed_available, is_fairscale_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
execute_subprocess_async,
|
||||
@@ -29,8 +30,7 @@ from transformers.testing_utils import (
|
||||
from transformers.trainer_callback import TrainerState
|
||||
from transformers.trainer_utils import set_seed
|
||||
|
||||
from .finetune_trainer import Seq2SeqTrainingArguments, main
|
||||
from .seq2seq_trainer import Seq2SeqTrainer
|
||||
from .finetune_trainer import main
|
||||
|
||||
|
||||
set_seed(42)
|
||||
@@ -38,9 +38,42 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
|
||||
MBART_TINY = "sshleifer/tiny-mbart"
|
||||
|
||||
|
||||
# a candidate for testing_utils
|
||||
def require_fairscale(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires fairscale
|
||||
"""
|
||||
if not is_fairscale_available():
|
||||
return unittest.skip("test requires fairscale")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
# a candidate for testing_utils
|
||||
def require_deepspeed(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires deepspeed
|
||||
"""
|
||||
if not is_deepspeed_available():
|
||||
return unittest.skip("test requires deepspeed")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
# a candidate for testing_utils
|
||||
def require_apex(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires apex
|
||||
"""
|
||||
if not is_apex_available():
|
||||
return unittest.skip("test requires apex")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
class TestFinetuneTrainer(TestCasePlus):
|
||||
def finetune_trainer_quick(self, distributed=None):
|
||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed)
|
||||
def finetune_trainer_quick(self, distributed=None, deepspeed=False, extra_args_str=None):
|
||||
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, deepspeed, extra_args_str)
|
||||
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
|
||||
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
|
||||
first_step_stats = eval_metrics[0]
|
||||
@@ -59,6 +92,26 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
def test_finetune_trainer_ddp(self):
|
||||
self.finetune_trainer_quick(distributed=True)
|
||||
|
||||
# it's crucial to test --sharded_ddp w/ and w/o --fp16
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
def test_finetune_trainer_ddp_sharded_ddp(self):
|
||||
self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp")
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@require_fairscale
|
||||
def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
|
||||
self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
|
||||
|
||||
@require_apex
|
||||
def test_finetune_trainer_apex(self):
|
||||
self.finetune_trainer_quick(extra_args_str="--fp16 --fp16_backend=apex")
|
||||
|
||||
@require_torch_multi_gpu
|
||||
@require_deepspeed
|
||||
def test_finetune_trainer_deepspeed(self):
|
||||
self.finetune_trainer_quick(deepspeed=True)
|
||||
|
||||
@slow
|
||||
def test_finetune_trainer_slow(self):
|
||||
# There is a missing call to __init__process_group somewhere
|
||||
@@ -81,121 +134,15 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
assert "test_generations.txt" in contents
|
||||
assert "test_results.json" in contents
|
||||
|
||||
@slow
|
||||
def test_finetune_bert2bert(self):
|
||||
if not is_datasets_available():
|
||||
return
|
||||
|
||||
import datasets
|
||||
|
||||
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size
|
||||
bert2bert.config.eos_token_id = tokenizer.sep_token_id
|
||||
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
|
||||
bert2bert.config.max_length = 128
|
||||
|
||||
train_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
|
||||
val_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")
|
||||
|
||||
train_dataset = train_dataset.select(range(32))
|
||||
val_dataset = val_dataset.select(range(16))
|
||||
|
||||
rouge = datasets.load_metric("rouge")
|
||||
|
||||
batch_size = 4
|
||||
|
||||
def _map_to_encoder_decoder_inputs(batch):
|
||||
# Tokenizer will automatically set [BOS] <text> [EOS]
|
||||
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
|
||||
outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
|
||||
batch["input_ids"] = inputs.input_ids
|
||||
batch["attention_mask"] = inputs.attention_mask
|
||||
|
||||
batch["decoder_input_ids"] = outputs.input_ids
|
||||
batch["labels"] = outputs.input_ids.copy()
|
||||
batch["labels"] = [
|
||||
[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
|
||||
]
|
||||
batch["decoder_attention_mask"] = outputs.attention_mask
|
||||
|
||||
assert all([len(x) == 512 for x in inputs.input_ids])
|
||||
assert all([len(x) == 128 for x in outputs.input_ids])
|
||||
|
||||
return batch
|
||||
|
||||
def _compute_metrics(pred):
|
||||
labels_ids = pred.label_ids
|
||||
pred_ids = pred.predictions
|
||||
|
||||
# all unnecessary tokens are removed
|
||||
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
||||
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
|
||||
|
||||
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
|
||||
"rouge2"
|
||||
].mid
|
||||
|
||||
return {
|
||||
"rouge2_precision": round(rouge_output.precision, 4),
|
||||
"rouge2_recall": round(rouge_output.recall, 4),
|
||||
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
|
||||
}
|
||||
|
||||
# map train dataset
|
||||
train_dataset = train_dataset.map(
|
||||
_map_to_encoder_decoder_inputs,
|
||||
batched=True,
|
||||
batch_size=batch_size,
|
||||
remove_columns=["article", "highlights"],
|
||||
)
|
||||
train_dataset.set_format(
|
||||
type="torch",
|
||||
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
||||
)
|
||||
|
||||
# same for validation dataset
|
||||
val_dataset = val_dataset.map(
|
||||
_map_to_encoder_decoder_inputs,
|
||||
batched=True,
|
||||
batch_size=batch_size,
|
||||
remove_columns=["article", "highlights"],
|
||||
)
|
||||
val_dataset.set_format(
|
||||
type="torch",
|
||||
columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
|
||||
)
|
||||
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
|
||||
training_args = Seq2SeqTrainingArguments(
|
||||
output_dir=output_dir,
|
||||
per_device_train_batch_size=batch_size,
|
||||
per_device_eval_batch_size=batch_size,
|
||||
predict_with_generate=True,
|
||||
evaluation_strategy="steps",
|
||||
do_train=True,
|
||||
do_eval=True,
|
||||
warmup_steps=0,
|
||||
eval_steps=2,
|
||||
logging_steps=2,
|
||||
)
|
||||
|
||||
# instantiate trainer
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=bert2bert,
|
||||
args=training_args,
|
||||
compute_metrics=_compute_metrics,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=val_dataset,
|
||||
)
|
||||
|
||||
# start training
|
||||
trainer.train()
|
||||
|
||||
def run_trainer(
|
||||
self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int, distributed: bool = False
|
||||
self,
|
||||
eval_steps: int,
|
||||
max_len: str,
|
||||
model_name: str,
|
||||
num_train_epochs: int,
|
||||
distributed: bool = False,
|
||||
deepspeed: bool = False,
|
||||
extra_args_str: str = None,
|
||||
):
|
||||
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
|
||||
output_dir = self.get_auto_remove_tmp_dir()
|
||||
@@ -223,7 +170,7 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
--save_steps {str(eval_steps)}
|
||||
--eval_steps {str(eval_steps)}
|
||||
--sortish_sampler
|
||||
--label_smoothing 0.1
|
||||
--label_smoothing_factor 0.1
|
||||
--adafactor
|
||||
--task translation
|
||||
--tgt_lang ro_RO
|
||||
@@ -231,7 +178,18 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
""".split()
|
||||
# --eval_beams 2
|
||||
|
||||
if distributed:
|
||||
if extra_args_str is not None:
|
||||
args.extend(extra_args_str.split())
|
||||
|
||||
if deepspeed:
|
||||
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
|
||||
distributed_args = f"""
|
||||
{self.test_file_dir}/finetune_trainer.py
|
||||
""".split()
|
||||
cmd = ["deepspeed"] + distributed_args + args + ds_args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
elif distributed:
|
||||
n_gpu = get_gpu_count()
|
||||
distributed_args = f"""
|
||||
-m torch.distributed.launch
|
||||
@@ -240,6 +198,7 @@ class TestFinetuneTrainer(TestCasePlus):
|
||||
""".split()
|
||||
cmd = [sys.executable] + distributed_args + args
|
||||
execute_subprocess_async(cmd, env=self.get_env())
|
||||
|
||||
else:
|
||||
testargs = ["finetune_trainer.py"] + args
|
||||
with patch.object(sys, "argv", testargs):
|
||||
|
||||
@@ -29,9 +29,10 @@ python finetune_trainer.py \
|
||||
--freeze_encoder --freeze_embeds \
|
||||
--num_train_epochs=6 \
|
||||
--save_steps 3000 --eval_steps 3000 \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN \
|
||||
--val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
|
||||
--do_train --do_eval --do_predict \
|
||||
--evaluation_strategy steps \
|
||||
--predict_with_generate --logging_first_step \
|
||||
--task translation --label_smoothing 0.1 \
|
||||
--task translation --label_smoothing_factor 0.1 \
|
||||
"$@"
|
||||
|
||||
@@ -30,9 +30,10 @@ python xla_spawn.py --num_cores $TPU_NUM_CORES \
|
||||
--num_train_epochs=6 \
|
||||
--save_steps 500 --eval_steps 500 \
|
||||
--logging_first_step --logging_steps 200 \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
|
||||
--max_source_length $MAX_LEN --max_target_length $MAX_LEN \
|
||||
--val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
|
||||
--do_train --do_eval \
|
||||
--evaluation_strategy steps \
|
||||
--prediction_loss_only \
|
||||
--task translation --label_smoothing 0.1 \
|
||||
--task translation --label_smoothing_factor 0.1 \
|
||||
"$@"
|
||||
|
||||
@@ -32,7 +32,7 @@ python finetune_trainer.py \
|
||||
--num_train_epochs=2 \
|
||||
--save_steps 3000 --eval_steps 3000 \
|
||||
--logging_first_step \
|
||||
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN \
|
||||
--max_target_length 56 --val_max_target_length $MAX_TGT_LEN --test_max_target_length $MAX_TGT_LEN\
|
||||
--do_train --do_eval --do_predict \
|
||||
--evaluation_strategy steps \
|
||||
--predict_with_generate --sortish_sampler \
|
||||
|
||||
@@ -24,8 +24,7 @@ python finetune_trainer.py \
|
||||
--src_lang en_XX --tgt_lang ro_RO \
|
||||
--freeze_embeds \
|
||||
--per_device_train_batch_size=4 --per_device_eval_batch_size=4 \
|
||||
--max_source_length 128 --max_target_length 128 \
|
||||
--val_max_target_length 128 --test_max_target_length 128 \
|
||||
--max_source_length 128 --max_target_length 128 --val_max_target_length 128 --test_max_target_length 128\
|
||||
--sortish_sampler \
|
||||
--num_train_epochs 6 \
|
||||
--save_steps 25000 --eval_steps 25000 --logging_steps 1000 \
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user