zero-shot pipeline multi_class -> multi_label (#10727)
This commit is contained in:
@@ -49,7 +49,7 @@ class TeacherModelArguments:
|
||||
teacher_batch_size: Optional[int] = field(
|
||||
default=32, metadata={"help": "Batch size for generating teacher predictions."}
|
||||
)
|
||||
multi_class: Optional[bool] = field(
|
||||
multi_label: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": (
|
||||
@@ -163,7 +163,7 @@ def get_teacher_predictions(
|
||||
hypothesis_template: str,
|
||||
batch_size: int,
|
||||
temperature: float,
|
||||
multi_class: bool,
|
||||
multi_label: bool,
|
||||
use_fast_tokenizer: bool,
|
||||
no_cuda: bool,
|
||||
fp16: bool,
|
||||
@@ -203,7 +203,7 @@ def get_teacher_predictions(
|
||||
logits = torch.cat(logits, dim=0) # N*K x 3
|
||||
nli_logits = logits.reshape(len(examples), len(class_names), -1)[..., [contr_id, entail_id]] # N x K x 2
|
||||
|
||||
if multi_class:
|
||||
if multi_label:
|
||||
# softmax over (contr, entail) logits for each class independently
|
||||
nli_prob = (nli_logits / temperature).softmax(-1)
|
||||
else:
|
||||
@@ -285,7 +285,7 @@ def main():
|
||||
teacher_args.hypothesis_template,
|
||||
teacher_args.teacher_batch_size,
|
||||
teacher_args.temperature,
|
||||
teacher_args.multi_class,
|
||||
teacher_args.multi_label,
|
||||
data_args.use_fast_tokenizer,
|
||||
training_args.no_cuda,
|
||||
training_args.fp16,
|
||||
|
||||
Reference in New Issue
Block a user