Fix memory leak with CTC training script on Chinese languages (#30358)
* Fix memory leak with CTC training script on Chinese languages
* Fix lint
This commit is contained in:
@@ -28,7 +28,6 @@ from typing import Dict, List, Optional, Union
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import evaluate
|
import evaluate
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import DatasetDict, load_dataset
|
from datasets import DatasetDict, load_dataset
|
||||||
|
|
||||||
@@ -712,10 +711,14 @@ def main():
|
|||||||
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
|
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
|
||||||
return
|
return
|
||||||
|
|
||||||
def compute_metrics(pred):
|
# For languages like Chinese with large vocabulary size, we need to discard logits
|
||||||
pred_logits = pred.predictions
|
# and only keep the argmax, otherwise we run out of memory during evaluation.
|
||||||
pred_ids = np.argmax(pred_logits, axis=-1)
|
def preprocess_logits_for_metrics(logits, labels):
|
||||||
|
pred_ids = torch.argmax(logits, dim=-1)
|
||||||
|
return pred_ids, labels
|
||||||
|
|
||||||
|
def compute_metrics(pred):
|
||||||
|
pred_ids = pred.predictions[0]
|
||||||
pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
|
pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
|
||||||
|
|
||||||
pred_str = tokenizer.batch_decode(pred_ids)
|
pred_str = tokenizer.batch_decode(pred_ids)
|
||||||
@@ -762,6 +765,7 @@ def main():
|
|||||||
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
|
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
|
||||||
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
|
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
|
||||||
tokenizer=processor,
|
tokenizer=processor,
|
||||||
|
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 8. Finally, we can start training
|
# 8. Finally, we can start training
|
||||||
|
|||||||
Reference in New Issue
Block a user