[style] consistent nn. and nn.functional: part 3 tests (#12155)
* consistent nn. and nn.functional: p3 templates * restore
This commit is contained in:
@@ -24,7 +24,7 @@ from .test_modeling_common import ids_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from transformers.generation_logits_process import (
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
@@ -80,13 +80,13 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
|
||||
|
||||
# compute softmax
|
||||
probs = F.softmax(scores, dim=-1)
|
||||
probs = nn.functional.softmax(scores, dim=-1)
|
||||
|
||||
temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5)
|
||||
temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3)
|
||||
|
||||
warped_prob_sharp = F.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
|
||||
warped_prob_smooth = F.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
|
||||
warped_prob_sharp = nn.functional.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1)
|
||||
warped_prob_smooth = nn.functional.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1)
|
||||
|
||||
# uniform distribution stays uniform
|
||||
self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user