prepare for "__floordiv__ is deprecated and its behavior will change in a future version of pytorch" (#20211)
* rounding_mode = "floor" instead of // to prevent behavioral change * add other TODO * use `torch_int_div` from pytrch_utils * same for tests * fix copies * style * use relative imports when needed * Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
@@ -32,6 +32,7 @@ if is_torch_available():
|
||||
DisjunctiveConstraint,
|
||||
PhrasalConstraint,
|
||||
)
|
||||
from transformers.pytorch_utils import torch_int_div
|
||||
|
||||
|
||||
class BeamSearchTester:
|
||||
@@ -160,10 +161,8 @@ class BeamSearchTester:
|
||||
expected_output_scores = cut_expected_tensor(next_scores)
|
||||
|
||||
# add num_beams * batch_idx
|
||||
expected_output_indices = (
|
||||
cut_expected_tensor(next_indices)
|
||||
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
|
||||
)
|
||||
offset = torch_int_div(torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams)
|
||||
expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams
|
||||
|
||||
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
|
||||
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
|
||||
@@ -399,10 +398,8 @@ class ConstrainedBeamSearchTester:
|
||||
expected_output_scores = cut_expected_tensor(next_scores)
|
||||
|
||||
# add num_beams * batch_idx
|
||||
expected_output_indices = (
|
||||
cut_expected_tensor(next_indices)
|
||||
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
|
||||
)
|
||||
offset = torch_int_div(torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams)
|
||||
expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams
|
||||
|
||||
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
|
||||
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
|
||||
|
||||
@@ -71,7 +71,7 @@ if is_torch_available():
|
||||
_compute_mask_indices,
|
||||
_sample_negative_indices,
|
||||
)
|
||||
from transformers.pytorch_utils import is_torch_less_than_1_9
|
||||
from transformers.pytorch_utils import is_torch_less_than_1_9, torch_int_div
|
||||
else:
|
||||
is_torch_less_than_1_9 = True
|
||||
|
||||
@@ -1217,10 +1217,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
sequence_length = 10
|
||||
hidden_size = 4
|
||||
num_negatives = 3
|
||||
|
||||
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
|
||||
sequence_length, hidden_size
|
||||
) # each value in vector consits of same value
|
||||
sequence = torch_int_div(torch.arange(sequence_length * hidden_size, device=torch_device), hidden_size)
|
||||
features = sequence.view(sequence_length, hidden_size) # each value in vector consits of same value
|
||||
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||
|
||||
# sample negative indices
|
||||
@@ -1247,9 +1245,8 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
|
||||
mask[-1, sequence_length // 2 :] = 0
|
||||
|
||||
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
|
||||
sequence_length, hidden_size
|
||||
) # each value in vector consits of same value
|
||||
sequence = torch_int_div(torch.arange(sequence_length * hidden_size, device=torch_device), hidden_size)
|
||||
features = sequence.view(sequence_length, hidden_size) # each value in vector consits of same value
|
||||
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||
|
||||
# replace masked feature vectors with -100 to test that those are not sampled
|
||||
|
||||
Reference in New Issue
Block a user