Generate: fix TF XLA tests on models with max_position_embeddings or max_target_positions (#21389)
This commit is contained in:
@@ -1865,6 +1865,17 @@ class TFModelTesterMixin:
|
|||||||
config.eos_token_id = None # Generate until max length
|
config.eos_token_id = None # Generate until max length
|
||||||
config.do_sample = False
|
config.do_sample = False
|
||||||
|
|
||||||
|
# fix config for models with additional sequence-length limiting settings
|
||||||
|
for var_name in ["max_position_embeddings", "max_target_positions"]:
|
||||||
|
attr = getattr(config, var_name, None)
|
||||||
|
if attr is not None and attr < generate_kwargs["max_new_tokens"]:
|
||||||
|
try:
|
||||||
|
setattr(config, var_name, generate_kwargs["max_new_tokens"])
|
||||||
|
except NotImplementedError:
|
||||||
|
# xlnet will raise an exception when trying to set
|
||||||
|
# max_position_embeddings.
|
||||||
|
pass
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
|
||||||
if model.supports_xla_generation:
|
if model.supports_xla_generation:
|
||||||
|
|||||||
Reference in New Issue
Block a user