Apply ruff flake8-comprehensions (#21694)

This commit is contained in:
Aaron Gokaslan
2023-02-22 03:14:54 -05:00
committed by GitHub
parent df06fb1f0b
commit 5e8c8eb5ba
230 changed files with 971 additions and 955 deletions

View File

@@ -310,12 +310,12 @@ def main():
if config.label2id != PretrainedConfig(num_labels=num_labels).label2id and not is_regression:
# Some have all caps in their config, some don't.
label_name_to_id = {k.lower(): v for k, v in config.label2id.items()}
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
if sorted(label_name_to_id.keys()) == sorted(label_list):
label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
"\nIgnoring the model labels as a result.",
)
label_to_id = {label: i for i, label in enumerate(label_list)}
@@ -383,7 +383,7 @@ def main():
dataset_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
num_replicas = training_args.strategy.num_replicas_in_sync
tf_data = dict()
tf_data = {}
max_samples = {
"train": data_args.max_train_samples,
"validation": data_args.max_eval_samples,