moved temperature wrapper before topP/topK (#8686)

This commit is contained in:
Roman Kalyakin
2020-11-20 19:33:54 +01:00
committed by GitHub
parent 8062fa63c5
commit 2594bd8b73

View File

@@ -244,12 +244,12 @@ class GenerationMixin:
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py` # all samplers can be found in `generation_utils_samplers.py`
if temperature is not None and temperature != 1.0:
warpers.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0: if top_k is not None and top_k != 0:
warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))) warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if top_p is not None and top_p < 1.0: if top_p is not None and top_p < 1.0:
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))) warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
if temperature is not None and temperature != 1.0:
warpers.append(TemperatureLogitsWarper(temperature))
return warpers return warpers
def _get_logits_processor( def _get_logits_processor(