Fix ROUGE add example check and update README (#18398)
* Fix ROUGE add example check and update README * Stay consistent in values
This commit is contained in:
@@ -625,12 +625,9 @@ def main():
|
||||
decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
||||
|
||||
result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
|
||||
# Extract a few results from ROUGE
|
||||
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
|
||||
|
||||
result = {k: round(v * 100, 4) for k, v in result.items()}
|
||||
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
|
||||
result["gen_len"] = np.mean(prediction_lens)
|
||||
result = {k: round(v, 4) for k, v in result.items()}
|
||||
return result
|
||||
|
||||
# Initialize our Trainer
|
||||
|
||||
@@ -51,10 +51,13 @@ from transformers import (
|
||||
SchedulerType,
|
||||
get_scheduler,
|
||||
)
|
||||
from transformers.utils import get_full_repo_name, is_offline_mode, send_example_telemetry
|
||||
from transformers.utils import check_min_version, get_full_repo_name, is_offline_mode, send_example_telemetry
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.22.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
@@ -707,10 +710,7 @@ def main():
|
||||
references=decoded_labels,
|
||||
)
|
||||
result = metric.compute(use_stemmer=True)
|
||||
# Extract a few results from ROUGE
|
||||
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
|
||||
|
||||
result = {k: round(v, 4) for k, v in result.items()}
|
||||
result = {k: round(v * 100, 4) for k, v in result.items()}
|
||||
|
||||
logger.info(result)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user