From 4a53e8e9e405779cc9f01c11c4d866b3fb6738e2 Mon Sep 17 00:00:00 2001 From: Jonathan Chang <31893406+cccntu@users.noreply.github.com> Date: Sun, 8 Nov 2020 06:53:01 -0800 Subject: [PATCH] Fix DataCollatorForWholeWordMask again (#8397) --- src/transformers/data/data_collator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 04b55b7b6a..ba94baaa7d 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -206,6 +206,10 @@ def _collate_batch(examples, tokenizer): return result +def tolist(x: Union[List[Any], torch.Tensor]): + return x.tolist() if isinstance(x, torch.Tensor) else x + + @dataclass class DataCollatorForLanguageModeling: """ @@ -320,13 +324,13 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling): mask_labels = [] for e in examples: ref_tokens = [] - for id in e["input_ids"].tolist(): + for id in tolist(e["input_ids"]): token = self.tokenizer._convert_id_to_token(id) ref_tokens.append(token) # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢] if "chinese_ref" in e: - ref_pos = e["chinese_ref"].tolist() + ref_pos = tolist(e["chinese_ref"]) len_seq = e["input_ids"].size(0) for i in range(len_seq): if i in ref_pos: