zero-shot pipeline multi_class -> multi_label (#10727)

This commit is contained in:
Joe Davison
2021-03-15 16:02:46 -06:00
committed by GitHub
parent 58f672e65c
commit 966ba081c9
3 changed files with 17 additions and 9 deletions

View File

@@ -49,7 +49,7 @@ class TeacherModelArguments:
teacher_batch_size: Optional[int] = field(
default=32, metadata={"help": "Batch size for generating teacher predictions."}
)
multi_class: Optional[bool] = field(
multi_label: Optional[bool] = field(
default=False,
metadata={
"help": (
@@ -163,7 +163,7 @@ def get_teacher_predictions(
hypothesis_template: str,
batch_size: int,
temperature: float,
multi_class: bool,
multi_label: bool,
use_fast_tokenizer: bool,
no_cuda: bool,
fp16: bool,
@@ -203,7 +203,7 @@ def get_teacher_predictions(
logits = torch.cat(logits, dim=0) # N*K x 3
nli_logits = logits.reshape(len(examples), len(class_names), -1)[..., [contr_id, entail_id]] # N x K x 2
if multi_class:
if multi_label:
# softmax over (contr, entail) logits for each class independently
nli_prob = (nli_logits / temperature).softmax(-1)
else:
@@ -285,7 +285,7 @@ def main():
teacher_args.hypothesis_template,
teacher_args.teacher_batch_size,
teacher_args.temperature,
teacher_args.multi_class,
teacher_args.multi_label,
data_args.use_fast_tokenizer,
training_args.no_cuda,
training_args.fp16,