Fix eval ref miss in Chinese WWM. (#8115)

* ADD: add whole word mask proxy for both eng and chinese

* MOD: adjust format

* MOD: reformat code

* MOD: update import

* MOD: fix bug

* MOD: add import

* MOD: fix bug

* MOD: decouple code and update readme

* MOD: reformat code

* Update examples/language-modeling/README.md

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/README.md

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* change wwm to whole_word_mask

* reformat code

* reformat

* format

* Code quality

* ADD: update chinese ref readme

* MOD: small changes

* MOD: small changes2

* update readme

* fix eval ref file miss bug

* format file

* MOD: move ref code to contrib

* MOD: add delimeter check

* reformat code

* refomat code

* Update examples/language-modeling/README.md

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
wlhgtc
2020-10-30 05:08:39 +08:00
committed by GitHub
parent fdf893c441
commit 9a21b50614
4 changed files with 26 additions and 16 deletions

View File

@@ -118,7 +118,7 @@ def main(args):
# If we want to fine-tune these model, we have to use same tokenizer : LTP (https://github.com/HIT-SCIR/ltp) # If we want to fine-tune these model, we have to use same tokenizer : LTP (https://github.com/HIT-SCIR/ltp)
with open(args.file_name, "r", encoding="utf-8") as f: with open(args.file_name, "r", encoding="utf-8") as f:
data = f.readlines() data = f.readlines()
data = [line.strip() for line in data if len(line) > 0 and not line.isspace()] # avoid delimiter like '\u2029'
ltp_tokenizer = LTP(args.ltp) # faster in GPU device ltp_tokenizer = LTP(args.ltp) # faster in GPU device
bert_tokenizer = BertTokenizer.from_pretrained(args.bert) bert_tokenizer = BertTokenizer.from_pretrained(args.bert)

View File

@@ -63,7 +63,7 @@ python run_language_modeling.py \
--whole_word_mask --whole_word_mask
``` ```
For Chinese models, it's same with English model with only --mlm`. If using whole-word masking, we need to generate a reference files, case it's char level. For Chinese models, it's same with English model with only `--mlm`. If using whole-word masking, we need to generate a reference files, because it's char level.
**Q :** Why ref file ? **Q :** Why ref file ?
@@ -76,15 +76,19 @@ So we need a ref file to tell model which pos of BERT original token should be a
**A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT. It works well on so many Chines Task like CLUE (Chinese GLUE). **A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT. It works well on so many Chines Task like CLUE (Chinese GLUE).
They use LTP, so if we want to fine-tune their model, we need LTP. They use LTP, so if we want to fine-tune their model, we need LTP.
Now LTP only only works well on `transformers==3.2.0`. So we don't add it to requirements.txt.
You need to check to `3.2.0` for `run_chinese_ref.py`. And the code could be found in `examples/contrib`.
```bash ```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export LTP_RESOURCE=/path/to/ltp/tokenizer export LTP_RESOURCE=/path/to/ltp/tokenizer
export BERT_RESOURCE=/path/to/bert/tokenizer export BERT_RESOURCE=/path/to/bert/tokenizer
export SAVE_PATH=/path/to/data/ref.txt export SAVE_PATH=/path/to/data/ref.txt
python chinese_ref.py \ python examples/contrib/run_chinese_ref.py \
--file_name=$TRAIN_FILE \ --file_name=$TRAIN_FILE \
--ltp=$LTP_RESOURCE --ltp=$LTP_RESOURCE \
--bert=$BERT_RESOURCE \ --bert=$BERT_RESOURCE \
--save_path=$SAVE_PATH --save_path=$SAVE_PATH
``` ```

View File

@@ -103,9 +103,13 @@ class DataTrainingArguments:
default=None, default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
) )
chinese_ref_file: Optional[str] = field( train_ref_file: Optional[str] = field(
default=None, default=None,
metadata={"help": "An optional input ref data file for whole word mask in Chinees."}, metadata={"help": "An optional input train ref data file for whole word mask in Chinese."},
)
eval_ref_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input eval ref data file for whole word mask in Chinese."},
) )
line_by_line: bool = field( line_by_line: bool = field(
default=False, default=False,
@@ -148,16 +152,16 @@ def get_dataset(
evaluate: bool = False, evaluate: bool = False,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
): ):
def _dataset(file_path): def _dataset(file_path, ref_path=None):
if args.line_by_line: if args.line_by_line:
if args.chinese_ref_file is not None: if ref_path is not None:
if not args.whole_word_mask or not args.mlm: if not args.whole_word_mask or not args.mlm:
raise ValueError("You need to set world whole masking and mlm to True for Chinese Whole Word Mask") raise ValueError("You need to set world whole masking and mlm to True for Chinese Whole Word Mask")
return LineByLineWithRefDataset( return LineByLineWithRefDataset(
tokenizer=tokenizer, tokenizer=tokenizer,
file_path=file_path, file_path=file_path,
block_size=args.block_size, block_size=args.block_size,
ref_path=args.chinese_ref_file, ref_path=ref_path,
) )
return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size) return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
@@ -171,11 +175,11 @@ def get_dataset(
) )
if evaluate: if evaluate:
return _dataset(args.eval_data_file) return _dataset(args.eval_data_file, args.eval_ref_file)
elif args.train_data_files: elif args.train_data_files:
return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)]) return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
else: else:
return _dataset(args.train_data_file) return _dataset(args.train_data_file, args.train_ref_file)
def main(): def main():

View File

@@ -128,15 +128,17 @@ class LineByLineWithRefDataset(Dataset):
logger.info("Creating features from dataset file at %s", file_path) logger.info("Creating features from dataset file at %s", file_path)
logger.info("Use ref segment results at %s", ref_path) logger.info("Use ref segment results at %s", ref_path)
with open(file_path, encoding="utf-8") as f: with open(file_path, encoding="utf-8") as f:
data = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())] data = f.readlines() # use this method to avoid delimiter '\u2029' to split a line
batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size) data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
self.examples = batch_encoding["input_ids"]
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
# Get ref inf from file # Get ref inf from file
with open(ref_path, encoding="utf-8") as f: with open(ref_path, encoding="utf-8") as f:
ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())] ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
assert len(data) == len(ref) assert len(data) == len(ref)
batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
self.examples = batch_encoding["input_ids"]
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
n = len(self.examples) n = len(self.examples)
for i in range(n): for i in range(n):
self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long) self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)