Upgrade black to version ~=22.0 (#15565)
* Upgrade black to version ~=22.0 * Check copies * Fix code
This commit is contained in:
@@ -102,7 +102,7 @@ class BeamSearchTester:
|
||||
beam_hyp.add(input_ids[beam_idx], -10.0 + float(beam_idx))
|
||||
|
||||
# -10.0 is removed => -9.0 is worst score
|
||||
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length ** beam_hyp.length_penalty))
|
||||
self.parent.assertAlmostEqual(beam_hyp.worst_score, -9.0 / (self.sequence_length**beam_hyp.length_penalty))
|
||||
|
||||
# -5.0 is better than worst score => should not be finished
|
||||
self.parent.assertFalse(beam_hyp.is_done(-5.0, self.sequence_length))
|
||||
|
||||
@@ -544,7 +544,7 @@ class IBertModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(q_int, q_int.round(), atol=1e-4))
|
||||
|
||||
# Output of the quantize Softmax should not exceed the output_bit
|
||||
self.assertTrue(q.abs().max() < 2 ** output_bit)
|
||||
self.assertTrue(q.abs().max() < 2**output_bit)
|
||||
|
||||
array = [[i + j for j in range(10)] for i in range(-10, 10)]
|
||||
_test(array)
|
||||
|
||||
@@ -252,7 +252,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
window_size_squared = config.window_size ** 2
|
||||
window_size_squared = config.window_size**2
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
@@ -134,7 +134,7 @@ class ViTMAEModelTester:
|
||||
patch_size = to_2tuple(self.patch_size)
|
||||
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
||||
expected_seq_len = num_patches
|
||||
expected_num_channels = self.patch_size ** 2 * self.num_channels
|
||||
expected_num_channels = self.patch_size**2 * self.num_channels
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
|
||||
@@ -68,7 +68,8 @@ class CopyCheckTester(unittest.TestCase):
|
||||
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
|
||||
if overwrite_result is not None:
|
||||
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
|
||||
code = black.format_str(code, mode=black.FileMode([black.TargetVersion.PY35], line_length=119))
|
||||
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
|
||||
code = black.format_str(code, mode=mode)
|
||||
fname = os.path.join(self.transformer_dir, "new_code.py")
|
||||
with open(fname, "w", newline="\n") as f:
|
||||
f.write(code)
|
||||
|
||||
Reference in New Issue
Block a user