diff --git a/examples/run_generation.py b/examples/run_generation.py index 13685c946c..ef58cfd844 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -139,7 +139,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k= next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.) # reptition penalty from CTRL (https://arxiv.org/abs/1909.05858) - for _ in set(generated): + for _ in set(generated.view(-1).tolist()): next_token_logits[_] /= repetition_penalty filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)