|
|
|
|
@@ -1,5 +1,4 @@
|
|
|
|
|
#!/usr/bin/env python
|
|
|
|
|
# coding=utf-8
|
|
|
|
|
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
@@ -21,7 +20,7 @@ import os
|
|
|
|
|
import random
|
|
|
|
|
import sys
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
import datasets
|
|
|
|
|
import evaluate
|
|
|
|
|
@@ -256,7 +255,7 @@ class ModelArguments:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_label_list(raw_dataset, split="train") -> List[str]:
|
|
|
|
|
def get_label_list(raw_dataset, split="train") -> list[str]:
|
|
|
|
|
"""Get the list of labels from a multi-label dataset"""
|
|
|
|
|
|
|
|
|
|
if isinstance(raw_dataset[split]["label"][0], list):
|
|
|
|
|
@@ -537,7 +536,7 @@ def main():
|
|
|
|
|
model.config.id2label = {id: label for label, id in label_to_id.items()}
|
|
|
|
|
elif not is_regression: # classification, but not training
|
|
|
|
|
logger.info("using label infos in the model config")
|
|
|
|
|
logger.info("label2id: {}".format(model.config.label2id))
|
|
|
|
|
logger.info(f"label2id: {model.config.label2id}")
|
|
|
|
|
label_to_id = model.config.label2id
|
|
|
|
|
else: # regression
|
|
|
|
|
label_to_id = None
|
|
|
|
|
@@ -549,7 +548,7 @@ def main():
|
|
|
|
|
)
|
|
|
|
|
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
|
|
|
|
|
|
|
|
|
def multi_labels_to_ids(labels: List[str]) -> List[float]:
|
|
|
|
|
def multi_labels_to_ids(labels: list[str]) -> list[float]:
|
|
|
|
|
ids = [0.0] * len(label_to_id) # BCELoss requires float as target type
|
|
|
|
|
for label in labels:
|
|
|
|
|
ids[label_to_id[label]] = 1.0
|
|
|
|
|
@@ -735,7 +734,7 @@ def main():
|
|
|
|
|
else:
|
|
|
|
|
item = label_list[item]
|
|
|
|
|
writer.write(f"{index}\t{item}\n")
|
|
|
|
|
logger.info("Predict results saved at {}".format(output_predict_file))
|
|
|
|
|
logger.info(f"Predict results saved at {output_predict_file}")
|
|
|
|
|
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
|
|
|
|
|
|
|
|
|
|
if training_args.push_to_hub:
|
|
|
|
|
|