Switch return_dict to True by default. (#8530)
* Use the CI to identify failing tests * Remove from all examples and tests * More default switch * Fixes * More test fixes * More fixes * Last fixes hopefully * Use the CI to identify failing tests * Remove from all examples and tests * More default switch * Fixes * More test fixes * More fixes * Last fixes hopefully * Run on the real suite * Fix slow tests
This commit is contained in:
@@ -153,7 +153,6 @@ class SummarizationDistiller(SummarizationModule):
|
||||
output_hidden_states=self.do_calc_hidden_loss,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
return_dict=True,
|
||||
)
|
||||
lm_logits = student_outputs.logits
|
||||
|
||||
@@ -179,7 +178,6 @@ class SummarizationDistiller(SummarizationModule):
|
||||
input_ids,
|
||||
attention_mask=src_mask,
|
||||
output_hidden_states=self.do_calc_hidden_loss,
|
||||
return_dict=True,
|
||||
)
|
||||
if self.different_base_models:
|
||||
teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
|
||||
@@ -199,7 +197,6 @@ class SummarizationDistiller(SummarizationModule):
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
output_hidden_states=self.do_calc_hidden_loss,
|
||||
use_cache=False, # since we are not passing labels, never let this default to True
|
||||
return_dict=True,
|
||||
)
|
||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits)
|
||||
|
||||
@@ -185,7 +185,7 @@ class TestSummarizationDistiller(TestCasePlus):
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_loss_fn(self):
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY, return_dict=True)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(BART_TINY)
|
||||
input_ids, mask = model.dummy_inputs["input_ids"], model.dummy_inputs["attention_mask"]
|
||||
target_ids = torch.tensor([[0, 4, 8, 2], [0, 8, 2, 1]], dtype=torch.long, device=model.device)
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
|
||||
|
||||
Reference in New Issue
Block a user