fix bert2bert test (#10063)
This commit is contained in:
committed by
GitHub
parent
31563e056d
commit
9e795eac88
@@ -24,15 +24,9 @@ if is_datasets_available():
|
|||||||
|
|
||||||
class Seq2seqTrainerTester(TestCasePlus):
|
class Seq2seqTrainerTester(TestCasePlus):
|
||||||
@slow
|
@slow
|
||||||
@require_datasets
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
@require_datasets
|
||||||
def test_finetune_bert2bert(self):
|
def test_finetune_bert2bert(self):
|
||||||
"""
|
|
||||||
Currently fails with:
|
|
||||||
|
|
||||||
ImportError: To be able to use this metric, you need to install the following dependencies['absl', 'nltk', 'rouge_score']
|
|
||||||
"""
|
|
||||||
|
|
||||||
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
|
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
|
||||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
|
||||||
@@ -47,8 +41,6 @@ class Seq2seqTrainerTester(TestCasePlus):
|
|||||||
train_dataset = train_dataset.select(range(32))
|
train_dataset = train_dataset.select(range(32))
|
||||||
val_dataset = val_dataset.select(range(16))
|
val_dataset = val_dataset.select(range(16))
|
||||||
|
|
||||||
rouge = datasets.load_metric("rouge")
|
|
||||||
|
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
|
|
||||||
def _map_to_encoder_decoder_inputs(batch):
|
def _map_to_encoder_decoder_inputs(batch):
|
||||||
@@ -78,15 +70,9 @@ class Seq2seqTrainerTester(TestCasePlus):
|
|||||||
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
||||||
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
|
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[
|
accuracy = sum([int(pred_str[i] == label_str[i]) for i in range(len(pred_str))]) / len(pred_str)
|
||||||
"rouge2"
|
|
||||||
].mid
|
|
||||||
|
|
||||||
return {
|
return {"accuracy": accuracy}
|
||||||
"rouge2_precision": round(rouge_output.precision, 4),
|
|
||||||
"rouge2_recall": round(rouge_output.recall, 4),
|
|
||||||
"rouge2_fmeasure": round(rouge_output.fmeasure, 4),
|
|
||||||
}
|
|
||||||
|
|
||||||
# map train dataset
|
# map train dataset
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
|
|||||||
Reference in New Issue
Block a user