FlaxGPT2 (#11556)
* flax gpt2 * combine masks * handle shared embeds * add causal LM sample * style * add tests * style * fix imports, docs, quality * don't use cache * add cache * add cache 1st version * make use cache work * start adding test for generation * finish generation loop compilation * rewrite test * finish * update * update * apply sylvains suggestions * update * refactor * fix typo Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -247,12 +247,8 @@ class FlaxModelTesterMixin:
|
||||
model = model_class(config)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
|
||||
return model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
).to_tuple()
|
||||
def model_jitted(input_ids, attention_mask=None, **kwargs):
|
||||
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
||||
@@ -266,11 +262,11 @@ class FlaxModelTesterMixin:
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None):
|
||||
def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
|
||||
return model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# jitted function cannot return OrderedDict
|
||||
|
||||
Reference in New Issue
Block a user