prepare for "__floordiv__ is deprecated and its behavior will change in a future version of pytorch" (#20211)

* rounding_mode = "floor"  instead of // to prevent behavioral change

* add other TODO

* use `torch_int_div` from pytrch_utils

* same for tests

* fix copies

* style

* use relative imports when needed

* Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
Arthur
2023-03-01 10:49:21 +01:00
committed by GitHub
parent b29e2dcaff
commit 44e3e3fb49
15 changed files with 53 additions and 51 deletions

View File

@@ -32,6 +32,7 @@ if is_torch_available():
DisjunctiveConstraint,
PhrasalConstraint,
)
from transformers.pytorch_utils import torch_int_div
class BeamSearchTester:
@@ -160,10 +161,8 @@ class BeamSearchTester:
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
expected_output_indices = (
cut_expected_tensor(next_indices)
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
)
offset = torch_int_div(torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams)
expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())
@@ -399,10 +398,8 @@ class ConstrainedBeamSearchTester:
expected_output_scores = cut_expected_tensor(next_scores)
# add num_beams * batch_idx
expected_output_indices = (
cut_expected_tensor(next_indices)
+ (torch.arange(self.num_beams * self.batch_size, device=torch_device) // self.num_beams) * self.num_beams
)
offset = torch_int_div(torch.arange(self.num_beams * self.batch_size, device=torch_device), self.num_beams)
expected_output_indices = cut_expected_tensor(next_indices) + offset * self.num_beams
self.parent.assertListEqual(expected_output_tokens.tolist(), output_tokens.tolist())
self.parent.assertListEqual(expected_output_indices.tolist(), output_indices.tolist())