Update ruff to 0.11.2 (#36962)
* update
* update
* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -2509,9 +2509,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
total_decoder_name="",
|
||||
total_encoder_name="",
|
||||
):
|
||||
assert isinstance(decoder_pointer, nn.Module) and isinstance(
|
||||
encoder_pointer, nn.Module
|
||||
), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
|
||||
assert isinstance(decoder_pointer, nn.Module) and isinstance(encoder_pointer, nn.Module), (
|
||||
f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
|
||||
)
|
||||
if hasattr(decoder_pointer, "weight"):
|
||||
assert hasattr(encoder_pointer, "weight")
|
||||
encoder_pointer.weight = decoder_pointer.weight
|
||||
@@ -2525,9 +2525,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
encoder_modules = encoder_pointer._modules
|
||||
decoder_modules = decoder_pointer._modules
|
||||
if len(decoder_modules) > 0:
|
||||
assert (
|
||||
len(encoder_modules) > 0
|
||||
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
|
||||
assert len(encoder_modules) > 0, (
|
||||
f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
|
||||
)
|
||||
|
||||
all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
|
||||
encoder_layer_pos = 0
|
||||
@@ -3571,7 +3571,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
f"Please upgrade accelerate with `pip install -U accelerate`"
|
||||
)
|
||||
# init state_dict for this shard
|
||||
shard_state_dict = {name: "" for name in shard}
|
||||
shard_state_dict = dict.fromkeys(shard, "")
|
||||
for module_name in shard:
|
||||
# skip to collect this weight again
|
||||
if shard_state_dict.get(module_name) != "":
|
||||
@@ -4814,7 +4814,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
param_device_map = expand_device_map(device_map, checkpoint_keys)
|
||||
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
|
||||
if sharded_metadata is None:
|
||||
weight_map = {p: checkpoint_files[0] for p in checkpoint_keys}
|
||||
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
|
||||
else:
|
||||
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
|
||||
# Fix the weight map keys according to the key mapping
|
||||
@@ -5446,9 +5446,9 @@ class PoolerEndLogits(nn.Module):
|
||||
Returns:
|
||||
`torch.FloatTensor`: The end logits for SQuAD.
|
||||
"""
|
||||
assert (
|
||||
start_states is not None or start_positions is not None
|
||||
), "One of start_states, start_positions should be not None"
|
||||
assert start_states is not None or start_positions is not None, (
|
||||
"One of start_states, start_positions should be not None"
|
||||
)
|
||||
if start_positions is not None:
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
||||
@@ -5514,9 +5514,9 @@ class PoolerAnswerClass(nn.Module):
|
||||
"""
|
||||
# No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
|
||||
hsz = hidden_states.shape[-1]
|
||||
assert (
|
||||
start_states is not None or start_positions is not None
|
||||
), "One of start_states, start_positions should be not None"
|
||||
assert start_states is not None or start_positions is not None, (
|
||||
"One of start_states, start_positions should be not None"
|
||||
)
|
||||
if start_positions is not None:
|
||||
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
||||
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
|
||||
|
||||
Reference in New Issue
Block a user