* flax gpt2

* combine masks

* handle shared embeds

* add causal LM sample

* style

* add tests

* style

* fix imports, docs, quality

* don't use cache

* add cache

* add cache 1st version

* make use cache work

* start adding test for generation

* finish generation loop compilation

* rewrite test

* finish

* update

* update

* apply sylvains suggestions

* update

* refactor

* fix typo

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil
2021-05-19 03:20:51 +05:30
committed by GitHub
parent eb3e072a3b
commit ca33278fdb
13 changed files with 1106 additions and 12 deletions

View File

@@ -355,7 +355,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| OpenAI GPT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | |
| OpenAI GPT-2 | ✅ | ✅ | ✅ | ✅ | |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Pegasus | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+