[style] consistent nn. and nn.functional: part 4 examples (#12156)
* consistent nn. and nn.functional: part 4 examples
* restore
This commit is contained in:
@@ -8,6 +8,7 @@ from pathlib import Path
|
||||
import pytest
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import lightning_base
|
||||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||
@@ -183,7 +184,7 @@ class TestSummarizationDistiller(TestCasePlus):
|
||||
|
||||
logits = model(input_ids, attention_mask=mask, decoder_input_ids=decoder_input_ids, use_cache=False).logits
|
||||
|
||||
lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
lprobs = nn.functional.log_softmax(logits, dim=-1)
|
||||
smoothed_loss, nll_loss = label_smoothed_nll_loss(
|
||||
lprobs, lm_labels, 0.1, ignore_index=model.config.pad_token_id
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import List
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from finetune import SummarizationModule, TranslationModule
|
||||
from finetune import main as ft_main
|
||||
@@ -123,8 +122,8 @@ class SummarizationDistiller(SummarizationModule):
|
||||
assert t_logits_slct.size() == s_logits_slct.size()
|
||||
loss_ce = (
|
||||
self.ce_loss_fct(
|
||||
F.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||
F.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||
nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
|
||||
nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
|
||||
)
|
||||
* (self.temperature) ** 2
|
||||
)
|
||||
@@ -160,10 +159,10 @@ class SummarizationDistiller(SummarizationModule):
|
||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||
if self.hparams.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||
loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
|
||||
else:
|
||||
lprobs = F.log_softmax(lm_logits, dim=-1)
|
||||
lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
|
||||
student_lm_loss, _ = label_smoothed_nll_loss(
|
||||
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||
)
|
||||
@@ -230,9 +229,9 @@ class SummarizationDistiller(SummarizationModule):
|
||||
teacher_states = torch.stack([hidden_states_T[j] for j in matches])
|
||||
assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
|
||||
if normalize_hidden:
|
||||
student_states = F.layer_norm(student_states, student_states.shape[1:])
|
||||
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
|
||||
mse = F.mse_loss(student_states, teacher_states, reduction="none")
|
||||
student_states = nn.functional.layer_norm(student_states, student_states.shape[1:])
|
||||
teacher_states = nn.functional.layer_norm(teacher_states, teacher_states.shape[1:])
|
||||
mse = nn.functional.mse_loss(student_states, teacher_states, reduction="none")
|
||||
masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
|
||||
return masked_mse
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Dict, List, Tuple
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
@@ -151,12 +152,12 @@ class SummarizationModule(BaseTransformer):
|
||||
lm_logits = outputs["logits"]
|
||||
if self.hparams.label_smoothing == 0:
|
||||
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
|
||||
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||
ce_loss_fct = nn.CrossEntropyLoss(ignore_index=pad_token_id)
|
||||
|
||||
assert lm_logits.shape[-1] == self.vocab_size
|
||||
loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
|
||||
else:
|
||||
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
|
||||
lprobs = nn.functional.log_softmax(lm_logits, dim=-1)
|
||||
loss, nll_loss = label_smoothed_nll_loss(
|
||||
lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user