Fix eval ref miss in Chinese WWM. (#8115)
* ADD: add whole word mask proxy for both eng and chinese * MOD: adjust format * MOD: reformat code * MOD: update import * MOD: fix bug * MOD: add import * MOD: fix bug * MOD: decouple code and update readme * MOD: reformat code * Update examples/language-modeling/README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/language-modeling/README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/language-modeling/run_language_modeling.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/language-modeling/run_language_modeling.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/language-modeling/run_language_modeling.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/language-modeling/run_language_modeling.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * change wwm to whole_word_mask * reformat code * reformat * format * Code quality * ADD: update chinese ref readme * MOD: small changes * MOD: small changes2 * update readme * fix eval ref file miss bug * format file * MOD: move ref code to contrib * MOD: add delimeter check * reformat code * refomat code * Update examples/language-modeling/README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -118,7 +118,7 @@ def main(args):
|
|||||||
# If we want to fine-tune these model, we have to use same tokenizer : LTP (https://github.com/HIT-SCIR/ltp)
|
# If we want to fine-tune these model, we have to use same tokenizer : LTP (https://github.com/HIT-SCIR/ltp)
|
||||||
with open(args.file_name, "r", encoding="utf-8") as f:
|
with open(args.file_name, "r", encoding="utf-8") as f:
|
||||||
data = f.readlines()
|
data = f.readlines()
|
||||||
|
data = [line.strip() for line in data if len(line) > 0 and not line.isspace()] # avoid delimiter like '\u2029'
|
||||||
ltp_tokenizer = LTP(args.ltp) # faster in GPU device
|
ltp_tokenizer = LTP(args.ltp) # faster in GPU device
|
||||||
bert_tokenizer = BertTokenizer.from_pretrained(args.bert)
|
bert_tokenizer = BertTokenizer.from_pretrained(args.bert)
|
||||||
|
|
||||||
@@ -63,7 +63,7 @@ python run_language_modeling.py \
|
|||||||
--whole_word_mask
|
--whole_word_mask
|
||||||
```
|
```
|
||||||
|
|
||||||
For Chinese models, it's same with English model with only --mlm`. If using whole-word masking, we need to generate a reference files, case it's char level.
|
For Chinese models, it's same with English model with only `--mlm`. If using whole-word masking, we need to generate a reference files, because it's char level.
|
||||||
|
|
||||||
**Q :** Why ref file ?
|
**Q :** Why ref file ?
|
||||||
|
|
||||||
@@ -76,15 +76,19 @@ So we need a ref file to tell model which pos of BERT original token should be a
|
|||||||
**A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT. It works well on so many Chines Task like CLUE (Chinese GLUE).
|
**A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT. It works well on so many Chines Task like CLUE (Chinese GLUE).
|
||||||
They use LTP, so if we want to fine-tune their model, we need LTP.
|
They use LTP, so if we want to fine-tune their model, we need LTP.
|
||||||
|
|
||||||
|
Now LTP only only works well on `transformers==3.2.0`. So we don't add it to requirements.txt.
|
||||||
|
You need to check to `3.2.0` for `run_chinese_ref.py`. And the code could be found in `examples/contrib`.
|
||||||
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
|
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
|
||||||
export LTP_RESOURCE=/path/to/ltp/tokenizer
|
export LTP_RESOURCE=/path/to/ltp/tokenizer
|
||||||
export BERT_RESOURCE=/path/to/bert/tokenizer
|
export BERT_RESOURCE=/path/to/bert/tokenizer
|
||||||
export SAVE_PATH=/path/to/data/ref.txt
|
export SAVE_PATH=/path/to/data/ref.txt
|
||||||
|
|
||||||
python chinese_ref.py \
|
python examples/contrib/run_chinese_ref.py \
|
||||||
--file_name=$TRAIN_FILE \
|
--file_name=$TRAIN_FILE \
|
||||||
--ltp=$LTP_RESOURCE
|
--ltp=$LTP_RESOURCE \
|
||||||
--bert=$BERT_RESOURCE \
|
--bert=$BERT_RESOURCE \
|
||||||
--save_path=$SAVE_PATH
|
--save_path=$SAVE_PATH
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -103,9 +103,13 @@ class DataTrainingArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
||||||
)
|
)
|
||||||
chinese_ref_file: Optional[str] = field(
|
train_ref_file: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "An optional input ref data file for whole word mask in Chinees."},
|
metadata={"help": "An optional input train ref data file for whole word mask in Chinese."},
|
||||||
|
)
|
||||||
|
eval_ref_file: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "An optional input eval ref data file for whole word mask in Chinese."},
|
||||||
)
|
)
|
||||||
line_by_line: bool = field(
|
line_by_line: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
@@ -148,16 +152,16 @@ def get_dataset(
|
|||||||
evaluate: bool = False,
|
evaluate: bool = False,
|
||||||
cache_dir: Optional[str] = None,
|
cache_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
def _dataset(file_path):
|
def _dataset(file_path, ref_path=None):
|
||||||
if args.line_by_line:
|
if args.line_by_line:
|
||||||
if args.chinese_ref_file is not None:
|
if ref_path is not None:
|
||||||
if not args.whole_word_mask or not args.mlm:
|
if not args.whole_word_mask or not args.mlm:
|
||||||
raise ValueError("You need to set world whole masking and mlm to True for Chinese Whole Word Mask")
|
raise ValueError("You need to set world whole masking and mlm to True for Chinese Whole Word Mask")
|
||||||
return LineByLineWithRefDataset(
|
return LineByLineWithRefDataset(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
block_size=args.block_size,
|
block_size=args.block_size,
|
||||||
ref_path=args.chinese_ref_file,
|
ref_path=ref_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
|
||||||
@@ -171,11 +175,11 @@ def get_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if evaluate:
|
if evaluate:
|
||||||
return _dataset(args.eval_data_file)
|
return _dataset(args.eval_data_file, args.eval_ref_file)
|
||||||
elif args.train_data_files:
|
elif args.train_data_files:
|
||||||
return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
|
return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
|
||||||
else:
|
else:
|
||||||
return _dataset(args.train_data_file)
|
return _dataset(args.train_data_file, args.train_ref_file)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -128,15 +128,17 @@ class LineByLineWithRefDataset(Dataset):
|
|||||||
logger.info("Creating features from dataset file at %s", file_path)
|
logger.info("Creating features from dataset file at %s", file_path)
|
||||||
logger.info("Use ref segment results at %s", ref_path)
|
logger.info("Use ref segment results at %s", ref_path)
|
||||||
with open(file_path, encoding="utf-8") as f:
|
with open(file_path, encoding="utf-8") as f:
|
||||||
data = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
data = f.readlines() # use this method to avoid delimiter '\u2029' to split a line
|
||||||
batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
|
data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
|
||||||
self.examples = batch_encoding["input_ids"]
|
|
||||||
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
|
|
||||||
|
|
||||||
# Get ref inf from file
|
# Get ref inf from file
|
||||||
with open(ref_path, encoding="utf-8") as f:
|
with open(ref_path, encoding="utf-8") as f:
|
||||||
ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
||||||
assert len(data) == len(ref)
|
assert len(data) == len(ref)
|
||||||
|
|
||||||
|
batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
|
||||||
|
self.examples = batch_encoding["input_ids"]
|
||||||
|
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
|
||||||
|
|
||||||
n = len(self.examples)
|
n = len(self.examples)
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
|
self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
|
||||||
|
|||||||
Reference in New Issue
Block a user