committed by
GitHub
parent
808d6c50f8
commit
63ca6d9771
@@ -101,6 +101,10 @@ class FlaxGenerationTesterMixin:
|
||||
pt_model = pt_model_class(config).eval()
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)
|
||||
|
||||
# Generate at most 5 tokens; otherwise numerical errors seem to accumulate
|
||||
pt_model.generation_config.max_length = 5
|
||||
flax_model.generation_config.max_length = 5
|
||||
|
||||
flax_generation_outputs = flax_model.generate(input_ids).sequences
|
||||
pt_generation_outputs = pt_model.generate(torch.tensor(input_ids, dtype=torch.long))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user