Apply ruff flake8-comprehensions (#21694)
This commit is contained in:
@@ -343,13 +343,13 @@ def main():
|
||||
if "train" in datasets:
|
||||
if not is_regression and config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
|
||||
label_name_to_id = config.label2id
|
||||
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
|
||||
if sorted(label_name_to_id.keys()) == sorted(label_list):
|
||||
label_to_id = label_name_to_id # Use the model's labels
|
||||
else:
|
||||
logger.warning(
|
||||
"Your model seems to have been trained with labels, but they don't match the dataset: ",
|
||||
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels:"
|
||||
f" {list(sorted(label_list))}.\nIgnoring the model labels as a result.",
|
||||
f"model labels: {sorted(label_name_to_id.keys())}, dataset labels:"
|
||||
f" {sorted(label_list)}.\nIgnoring the model labels as a result.",
|
||||
)
|
||||
label_to_id = {v: i for i, v in enumerate(label_list)}
|
||||
elif not is_regression:
|
||||
@@ -411,7 +411,7 @@ def main():
|
||||
dataset_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
|
||||
num_replicas = training_args.strategy.num_replicas_in_sync
|
||||
|
||||
tf_data = dict()
|
||||
tf_data = {}
|
||||
max_samples = {
|
||||
"train": data_args.max_train_samples,
|
||||
"validation": data_args.max_val_samples,
|
||||
|
||||
Reference in New Issue
Block a user