fixed typos (issue 27919) (#27920)
* fixed typos (issue 27919) * Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -61,8 +61,8 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
|
|
||||||
class ImageDistilTrainer(Trainer):
|
class ImageDistilTrainer(Trainer):
|
||||||
def __init__(self, *args, teacher_model=None, **kwargs):
|
def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(model=student_model, *args, **kwargs)
|
||||||
self.teacher = teacher_model
|
self.teacher = teacher_model
|
||||||
self.student = student_model
|
self.student = student_model
|
||||||
self.loss_function = nn.KLDivLoss(reduction="batchmean")
|
self.loss_function = nn.KLDivLoss(reduction="batchmean")
|
||||||
@@ -164,7 +164,7 @@ trainer = ImageDistilTrainer(
|
|||||||
train_dataset=processed_datasets["train"],
|
train_dataset=processed_datasets["train"],
|
||||||
eval_dataset=processed_datasets["validation"],
|
eval_dataset=processed_datasets["validation"],
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
tokenizer=teacher_extractor,
|
tokenizer=teacher_processor,
|
||||||
compute_metrics=compute_metrics,
|
compute_metrics=compute_metrics,
|
||||||
temperature=5,
|
temperature=5,
|
||||||
lambda_param=0.5
|
lambda_param=0.5
|
||||||
|
|||||||
Reference in New Issue
Block a user