From 594202a9348d7c2f27f7deaf1a7308e3751b3fbc Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Wed, 2 Oct 2019 11:00:57 -0400 Subject: [PATCH] lm_seqs_dataset --- .../{dataset.py => lm_seqs_dataset.py} | 124 ++++++------------ 1 file changed, 37 insertions(+), 87 deletions(-) rename examples/distillation/{dataset.py => lm_seqs_dataset.py} (54%) diff --git a/examples/distillation/dataset.py b/examples/distillation/lm_seqs_dataset.py similarity index 54% rename from examples/distillation/dataset.py rename to examples/distillation/lm_seqs_dataset.py index 4babf73ea4..54e9742ce8 100644 --- a/examples/distillation/dataset.py +++ b/examples/distillation/lm_seqs_dataset.py @@ -12,30 +12,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Dataloaders to train DistilBERT +""" Dataset to distilled models adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) """ -from typing import List -import math -from itertools import chain -from collections import Counter -import numpy as np import torch +from torch.utils.data import Dataset +import numpy as np from utils import logger -class Dataset: +class LmSeqsDataset(Dataset): + """Custom Dataset wrapping language modeling sequences. + + Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths. + + Input: + ------ + params: `NameSpace` parameters + data: `List[np.array[int]] + """ + def __init__(self, params, data): self.params = params - self.tokens_per_batch = params.tokens_per_batch - self.batch_size = params.batch_size - self.shuffle = params.shuffle - self.group_by_size = params.group_by_size self.token_ids = np.array(data) - self.lengths = np.uint16([len(t) for t in data]) + self.lengths = np.array([len(t) for t in data]) self.check() self.remove_long_sequences() @@ -43,6 +46,9 @@ class Dataset: self.check() self.print_statistics() + def __getitem__(self, index): + return (self.token_ids[index], self.lengths[index]) + def __len__(self): return len(self.lengths) @@ -51,12 +57,14 @@ class Dataset: Some sanity checks """ assert len(self.token_ids) == len(self.lengths) + assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths))) def remove_long_sequences(self): """ - Sequences that are too long are splitted by chunk of max_position_embeddings. + Sequences that are too long are splitted by chunk of max_model_input_size. """ - indices = self.lengths >= self.params.max_position_embeddings + max_len = self.params.max_model_input_size + indices = self.lengths > max_len logger.info(f'Splitting {sum(indices)} too long sequences.') def divide_chunks(l, n): @@ -64,10 +72,13 @@ class Dataset: new_tok_ids = [] new_lengths = [] - cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token'] - max_len = self.params.max_position_embeddings + if self.params.mlm: + cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token'] + else: + cls_id, sep_id = self.params.special_tok_ids['bos_token'], self.params.special_tok_ids['eos_token'] for seq_, len_ in zip(self.token_ids, self.lengths): + assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_ if len_ <= max_len: new_tok_ids.append(seq_) new_lengths.append(len_) @@ -79,6 +90,7 @@ class Dataset: if sub_s[-1] != sep_id: sub_s = np.insert(sub_s, len(sub_s), sep_id) assert len(sub_s) <= max_len + assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s sub_seqs.append(sub_s) new_tok_ids.extend(sub_seqs) @@ -113,89 +125,27 @@ class Dataset: # nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids]) # logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)') - def select_data(self, a: int, b: int): - """ - Select a subportion of the data. - """ - n_sequences = len(self) - assert 0 <= a < b <= n_sequences, ValueError(f'`0 <= a < b <= n_sequences` is not met with a={a} and b={b}') - - logger.info(f'Selecting sequences from {a} to {b} (excluded).') - self.token_ids = self.token_ids[a:b] - self.lengths = self.lengths[a:b] - - self.check() - - def split(self): - """ - Distributed training: split the data accross the processes. - """ - assert self.params.n_gpu > 1 - logger.info('Splitting the data accross the processuses.') - n_seq = len(self) - n_seq_per_procesus = n_seq // self.params.world_size - a = n_seq_per_procesus * self.params.global_rank - b = a + n_seq_per_procesus - self.select_data(a=a, b=b) - def batch_sequences(self, - token_ids: List[List[int]], - lengths: List[int]): + batch): """ Do the padding and transform into torch.tensor. """ + token_ids = [t[0] for t in batch] + lengths = [t[1] for t in batch] assert len(token_ids) == len(lengths) # Max for paddings max_seq_len_ = max(lengths) # Pad token ids - pad_idx = self.params.special_tok_ids['pad_token'] + if self.params.mlm: + pad_idx = self.params.special_tok_ids['pad_token'] + else: + pad_idx = self.params.special_tok_ids['unk_token'] tk_ = [list(t.astype(int)) + [pad_idx]*(max_seq_len_-len(t)) for t in token_ids] assert len(tk_) == len(token_ids) assert all(len(t) == max_seq_len_ for t in tk_) - tk_t = torch.tensor(tk_) # (bs, max_seq_len_) - lg_t = torch.tensor(lengths.astype(int)) # (bs) + tk_t = torch.tensor(tk_) # (bs, max_seq_len_) + lg_t = torch.tensor(lengths) # (bs) return tk_t, lg_t - - def get_batches_iterator(self, - batches): - """ - Return an iterator over batches. - """ - for sequences_ids in batches: - token_ids, lengths = self.batch_sequences(self.token_ids[sequences_ids], - self.lengths[sequences_ids]) - yield (token_ids, lengths) - - def get_iterator(self, - seed: int = None): - """ - Return a data iterator. - """ - rng = np.random.RandomState(seed) - - n_sequences = len(self) - indices = np.arange(n_sequences) - - if self.group_by_size: - indices = indices[np.argsort(self.lengths[indices], kind='mergesort')] - - if self.tokens_per_batch == -1: - batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size)) - else: - assert self.tokens_per_batch > 0 - batch_ids = np.cumsum(self.lengths[indices]) // self.tokens_per_batch - _, bounds = np.unique(batch_ids, return_index=True) - batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)] - if bounds[-1] < len(indices): - batches.append(indices[bounds[-1]:]) - - if self.shuffle: - rng.shuffle(batches) - - assert n_sequences == sum([len(x) for x in batches]) - assert self.lengths[indices].sum() == sum([self.lengths[x].sum() for x in batches]) - - return self.get_batches_iterator(batches=batches)