From 991172922f9711d7bef160d6aedb2ed1059a88ff Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 3 Jul 2020 19:25:25 +0200 Subject: [PATCH] better error message (#5497) --- src/transformers/generation_tf_utils.py | 4 ++++ src/transformers/generation_utils.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 675ce26782..f84d420284 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -347,6 +347,10 @@ class TFGenerationMixin: encoder_outputs = None cur_len = shape_list(input_ids)[-1] + assert ( + cur_len < max_length + ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" + if num_beams > 1: output = self._generate_beam_search( input_ids, diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index ec0bf803ee..3c49e15abf 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -428,6 +428,10 @@ class GenerationMixin: encoder_outputs = None cur_len = input_ids.shape[-1] + assert ( + cur_len < max_length + ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`" + if num_beams > 1: output = self._generate_beam_search( input_ids,