* flax gpt2

* combine masks

* handle shared embeds

* add causal LM sample

* style

* add tests

* style

* fix imports, docs, quality

* don't use cache

* add cache

* add cache 1st version

* make use cache work

* start adding test for generation

* finish generation loop compilation

* rewrite test

* finish

* update

* update

* apply Sylvain's suggestions

* update

* refactor

* fix typo

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil
2021-05-19 03:20:51 +05:30
committed by GitHub
parent eb3e072a3b
commit ca33278fdb
13 changed files with 1106 additions and 12 deletions

View File

@@ -247,12 +247,8 @@ class FlaxModelTesterMixin:
model = model_class(config)
@jax.jit
-def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
-    return model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        token_type_ids=token_type_ids,
-    ).to_tuple()
+def model_jitted(input_ids, attention_mask=None, **kwargs):
+    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
@@ -266,11 +262,11 @@ class FlaxModelTesterMixin:
self.assertEqual(jitted_output.shape, output.shape)
@jax.jit
-def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None):
+def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
     return model(
         input_ids=input_ids,
         attention_mask=attention_mask,
-        token_type_ids=token_type_ids,
+        **kwargs,
)
# jitted function cannot return OrderedDict