[style] consistent nn. and nn.functional: part 3 tests (#12155)

* consistent nn. and nn.functional: p3 templates

* restore
This commit is contained in:
Stas Bekman
2021-06-14 12:18:22 -07:00
committed by GitHub
parent d9c0d08f9a
commit 372ab9cd6d
14 changed files with 93 additions and 81 deletions

View File

@@ -24,7 +24,7 @@ from .test_modeling_common import ids_tensor
if is_torch_available():
import torch
import torch.nn.functional as F
from torch import nn
from transformers.generation_logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
@@ -80,13 +80,13 @@ class LogitsProcessorTest(unittest.TestCase):
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
# compute softmax
probs = F.softmax(scores, dim=-1)
probs = nn.functional.softmax(scores, dim=-1)
temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5)
temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3)
warped_prob_sharp = F.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
warped_prob_smooth = F.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
warped_prob_sharp = nn.functional.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
warped_prob_smooth = nn.functional.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
# uniform distribution stays uniform
self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))