Black preview (#17217)
* Black preview * Fixup too! * Fix check copies * Use the same version as the CI * Bump black
This commit is contained in:
@@ -592,7 +592,7 @@ class Matcher(object):
|
||||
|
||||
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
|
||||
|
||||
for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
|
||||
for l, low, high in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
|
||||
low_high = (matched_vals >= low) & (matched_vals < high)
|
||||
match_labels[low_high] = l
|
||||
|
||||
@@ -1037,9 +1037,9 @@ class ResNet(Backbone):
|
||||
curr_kwargs = {}
|
||||
for k, v in kwargs.items():
|
||||
if k.endswith("_per_block"):
|
||||
assert len(v) == num_blocks, (
|
||||
f"Argument '{k}' of make_stage should have the " f"same length as num_blocks={num_blocks}."
|
||||
)
|
||||
assert (
|
||||
len(v) == num_blocks
|
||||
), f"Argument '{k}' of make_stage should have the same length as num_blocks={num_blocks}."
|
||||
newk = k[: -len("_per_block")]
|
||||
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
|
||||
curr_kwargs[newk] = v[i]
|
||||
@@ -1401,7 +1401,7 @@ class AnchorGenerator(nn.Module):
|
||||
|
||||
def grid_anchors(self, grid_sizes):
|
||||
anchors = []
|
||||
for (size, stride, base_anchors) in zip(grid_sizes, self.strides, self.cell_anchors):
|
||||
for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
|
||||
shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
|
||||
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
|
||||
|
||||
@@ -1708,10 +1708,9 @@ class GeneralizedRCNN(nn.Module):
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
||||
assert (
|
||||
from_tf
|
||||
), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
|
||||
pretrained_model_name_or_path + ".index"
|
||||
assert from_tf, (
|
||||
"We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint"
|
||||
.format(pretrained_model_name_or_path + ".index")
|
||||
)
|
||||
archive_file = pretrained_model_name_or_path + ".index"
|
||||
else:
|
||||
@@ -1797,26 +1796,28 @@ class GeneralizedRCNN(nn.Module):
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
print(
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
||||
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
|
||||
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
|
||||
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
|
||||
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
|
||||
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
||||
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
||||
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
|
||||
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
||||
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
||||
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
|
||||
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
|
||||
)
|
||||
else:
|
||||
print(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
||||
if len(missing_keys) > 0:
|
||||
print(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
|
||||
f"and are newly initialized: {missing_keys}\n"
|
||||
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
||||
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
||||
f"If your task is similar to the task the model of the checkpoint was trained on, "
|
||||
f"you can already use {model.__class__.__name__} for predictions without further training."
|
||||
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
|
||||
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
|
||||
" training."
|
||||
)
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -231,9 +231,10 @@ def compare(in_tensor):
|
||||
n2 = out_tensor.numpy()[0]
|
||||
print(n1.shape, n1[0, 0, :5])
|
||||
print(n2.shape, n2[0, 0, :5])
|
||||
assert np.allclose(
|
||||
n1, n2, rtol=0.01, atol=0.1
|
||||
), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
|
||||
assert np.allclose(n1, n2, rtol=0.01, atol=0.1), (
|
||||
f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} %"
|
||||
" element-wise mismatch"
|
||||
)
|
||||
raise Exception("tensors are all good")
|
||||
|
||||
# Hugging face functions below
|
||||
|
||||
Reference in New Issue
Block a user