[Model card] Bert2Bert
Add Rouge2 results
This commit is contained in:
committed by
GitHub
parent
9d37c56bab
commit
12f14710ce
@@ -125,12 +125,10 @@ def compute_metrics(pred):
|
|||||||
labels_ids = pred.label_ids
|
labels_ids = pred.label_ids
|
||||||
pred_ids = pred.predictions
|
pred_ids = pred.predictions
|
||||||
|
|
||||||
pred_str = tokenizer.batch_decode(pred_ids, clean_special_tokens=True)
|
# all unnecessary tokens are removed
|
||||||
label_str = tokenizer.batch_decode(labels_ids, clean_special_tokens=True)
|
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
||||||
|
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
|
||||||
pred_str = [pred.split("[CLS]")[-1].split("[SEP]")[0] for pred in pred_str]
|
|
||||||
label_str = [label.split("[CLS]")[-1].split("[SEP]")[0] for label in label_str]
|
|
||||||
|
|
||||||
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
|
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -189,6 +187,43 @@ trainer = Trainer(
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
```
|
```
|
||||||
|
|
||||||
## Results
|
## Evaluation
|
||||||
|
|
||||||
TODO
|
The following script evaluates the model on the test set of
|
||||||
|
CNN/Daily Mail.
|
||||||
|
|
||||||
|
```python
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import nlp
|
||||||
|
from transformers import BertTokenizer, EncoderDecoderModel
|
||||||
|
tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||||
|
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||||
|
model.to("cuda")
|
||||||
|
test_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="test")
|
||||||
|
batch_size = 128
|
||||||
|
# map data correctly
|
||||||
|
def generate_summary(batch):
|
||||||
|
# Tokenizer will automatically set [BOS] <text> [EOS]
|
||||||
|
# cut off at BERT max length 512
|
||||||
|
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
|
||||||
|
input_ids = inputs.input_ids.to("cuda")
|
||||||
|
attention_mask = inputs.attention_mask.to("cuda")
|
||||||
|
outputs = model.generate(input_ids, attention_mask=attention_mask)
|
||||||
|
# all special tokens including will be removed
|
||||||
|
output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
|
batch["pred"] = output_str
|
||||||
|
return batch
|
||||||
|
results = test_dataset.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])
|
||||||
|
# load rouge for validation
|
||||||
|
rouge = nlp.load_metric("rouge")
|
||||||
|
pred_str = results["pred"]
|
||||||
|
label_str = results["highlights"]
|
||||||
|
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
|
||||||
|
print(rouge_output)
|
||||||
|
```
|
||||||
|
|
||||||
|
The obtained results should be:
|
||||||
|
|
||||||
|
| - | Rouge2 - mid -precision | Rouge2 - mid - recall | Rouge2 - mid - fmeasure |
|
||||||
|
|----------|:-------------:|:------:|:------:|
|
||||||
|
| **CNN/Daily Mail** | 14.12 | 14.37 | **13.8** |
|
||||||
|
|||||||
Reference in New Issue
Block a user