From 1ae81e4aa1868eb24d975ebff4a7241ed10975fc Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Wed, 28 Aug 2019 01:10:05 +0000 Subject: [PATCH] add dataset. distiller, utils --- examples/distillation/dataset.py | 184 ++++++++++++ examples/distillation/distiller.py | 431 +++++++++++++++++++++++++++++ examples/distillation/utils.py | 112 ++++++++ 3 files changed, 727 insertions(+) create mode 100644 examples/distillation/dataset.py create mode 100644 examples/distillation/distiller.py create mode 100644 examples/distillation/utils.py diff --git a/examples/distillation/dataset.py b/examples/distillation/dataset.py new file mode 100644 index 0000000000..6256ce1144 --- /dev/null +++ b/examples/distillation/dataset.py @@ -0,0 +1,184 @@ +from typing import List +import math +from itertools import chain +from collections import Counter +import numpy as np +import torch + +from utils import logger + +class Dataset: + def __init__(self, + params, + data): + self.params = params + self.tokens_per_batch = params.tokens_per_batch + self.batch_size = params.batch_size + self.shuffle = params.shuffle + self.group_by_size = params.group_by_size + + self.token_ids = np.array(data) + self.lengths = np.uint16([len(t) for t in data]) + + self.check() + self.remove_long_sequences() + self.remove_empty_sequences() + self.check() + self.print_statistics() + + def __len__(self): + return len(self.lengths) + + def check(self): + """ + Some sanity checks + """ + assert len(self.token_ids) == len(self.lengths) + + def remove_long_sequences(self): + """ + Sequences that are too long are splitted by chunk of max_position_embeddings. + """ + indices = self.lengths >= self.params.max_position_embeddings + logger.info(f'Splitting {sum(indices)} too long sequences.') + + def divide_chunks(l, n): + return [l[i:i + n] for i in range(0, len(l), n)] + + new_tok_ids = [] + new_lengths = [] + cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token'] + max_len = self.params.max_position_embeddings + + for seq_, len_ in zip(self.token_ids, self.lengths): + if len_ <= max_len: + new_tok_ids.append(seq_) + new_lengths.append(len_) + else: + sub_seqs = [] + for sub_s in divide_chunks(seq_, max_len-2): + if sub_s[0] != cls_id: + sub_s = np.insert(sub_s, 0, cls_id) + if sub_s[-1] != sep_id: + sub_s = np.insert(sub_s, len(sub_s), cls_id) + assert len(sub_s) <= max_len + sub_seqs.append(sub_s) + + new_tok_ids.extend(sub_seqs) + new_lengths.extend([len(l) for l in sub_seqs]) + + self.token_ids = np.array(new_tok_ids) + self.lengths = np.array(new_lengths) + + def remove_empty_sequences(self): + """ + Too short sequences are simply removed. This could be tunedd. + """ + init_size = len(self) + indices = self.lengths > 5 + self.token_ids = self.token_ids[indices] + self.lengths = self.lengths[indices] + new_size = len(self) + logger.info(f'Remove {init_size - new_size} too short (<=5 tokens) sequences.') + + def print_statistics(self): + """ + Print some statistics on the corpus. Only the master process. + """ + if not self.params.is_master: + return + logger.info(f'{len(self)} sequences') + # data_len = sum(self.lengths) + # nb_unique_tokens = len(Counter(list(chain(*self.token_ids)))) + # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)') + + # unk_idx = self.params.special_tok_ids['unk_token'] + # nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids]) + # logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)') + + def select_data(self, a: int, b: int): + """ + Select a subportion of the data. + """ + n_sequences = len(self) + assert 0 <= a < b <= n_sequences, ValueError(f'`0 <= a < b <= n_sequences` is not met with a={a} and b={b}') + + logger.info(f'Selecting sequences from {a} to {b} (excluded).') + self.token_ids = self.token_ids[a:b] + self.lengths = self.lengths[a:b] + + self.check() + + def split(self): + """ + Distributed training: split the data accross the processes. + """ + assert self.params.n_gpu > 1 + logger.info('Splitting the data accross the processuses.') + n_seq = len(self) + n_seq_per_procesus = n_seq // self.params.world_size + a = n_seq_per_procesus * self.params.global_rank + b = a + n_seq_per_procesus + self.select_data(a=a, b=b) + + def batch_sequences(self, + token_ids: List[List[int]], + lengths: List[int]): + """ + Do the padding and transform into torch.tensor. + """ + assert len(token_ids) == len(lengths) + + # Max for paddings + max_seq_len_ = max(lengths) + + # Pad token ids + pad_idx = self.params.special_tok_ids['pad_token'] + tk_ = [list(t.astype(int)) + [pad_idx]*(max_seq_len_-len(t)) for t in token_ids] + assert len(tk_) == len(token_ids) + assert all(len(t) == max_seq_len_ for t in tk_) + + tk_t = torch.tensor(tk_) # (bs, max_seq_len_) + lg_t = torch.tensor(lengths.astype(int)) # (bs) + return tk_t, lg_t + + def get_batches_iterator(self, + batches): + """ + Return an iterator over batches. + """ + for sequences_ids in batches: + token_ids, lengths = self.batch_sequences(self.token_ids[sequences_ids], + self.lengths[sequences_ids]) + yield (token_ids, lengths) + + def get_iterator(self, + seed: int = None): + """ + Return a data iterator. + """ + rng = np.random.RandomState(seed) + + n_sequences = len(self) + indices = np.arange(n_sequences) + + if self.group_by_size: + indices = indices[np.argsort(self.lengths[indices], kind='mergesort')] + + if self.tokens_per_batch == -1: + batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size)) + else: + assert self.tokens_per_batch > 0 + batch_ids = np.cumsum(self.lengths[indices]) // self.tokens_per_batch + _, bounds = np.unique(batch_ids, return_index=True) + batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)] + if bounds[-1] < len(indices): + batches.append(indices[bounds[-1]:]) + + if self.shuffle: + rng.shuffle(batches) + + assert n_sequences == sum([len(x) for x in batches]) + assert self.lengths[indices].sum() == sum([self.lengths[x].sum() for x in batches]) + + return self.get_batches_iterator(batches=batches) diff --git a/examples/distillation/distiller.py b/examples/distillation/distiller.py new file mode 100644 index 0000000000..c9c4458abc --- /dev/null +++ b/examples/distillation/distiller.py @@ -0,0 +1,431 @@ +import os +import math +from tensorboardX import SummaryWriter +from tqdm import trange, tqdm +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pytorch_transformers import AdamW, WarmupLinearSchedule + +from utils import logger +from dataset import Dataset + +class Distiller: + def __init__(self, + params: dict, + dataloader: Dataset, + token_probs: torch.tensor, + student: nn.Module, + teacher: nn.Module): + logger.info('Initializing Distiller') + self.params = params + self.dump_path = params.dump_path + self.multi_gpu = params.multi_gpu + self.fp16 = params.fp16 + + self.student = student + self.teacher = teacher + + self.dataloader = dataloader + if self.params.n_gpu > 1: + self.dataloader.split() + self.get_iterator(seed=params.seed) + + self.temperature = params.temperature + assert self.temperature > 0. + + self.alpha_ce = params.alpha_ce + self.alpha_mlm = params.alpha_mlm + self.alpha_mse = params.alpha_mse + assert self.alpha_ce >= 0. + assert self.alpha_mlm >= 0. + assert self.alpha_mse >= 0. + assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0. + + self.mlm_mask_prop = params.mlm_mask_prop + assert 0.0 <= self.mlm_mask_prop <= 1.0 + assert params.word_mask + params.word_keep + params.word_rand == 1.0 + self.pred_probs = torch.FloatTensor([params.word_mask, params.word_keep, params.word_rand]) + self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else self.pred_probs + self.token_probs = token_probs.to(f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs + if self.fp16: + self.pred_probs = self.pred_probs.half() + self.token_probs = self.token_probs.half() + + self.epoch = 0 + self.n_iter = 0 + self.n_total_iter = 0 + self.n_sequences_epoch = 0 + self.total_loss_epoch = 0 + self.last_loss = 0 + self.last_loss_ce = 0 + self.last_loss_mlm = 0 + self.last_loss_mse = 0 + + self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean') + self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) + self.mse_loss_fct = nn.MSELoss(reduction='sum') + + logger.info('--- Initializing model optimizer') + assert params.gradient_accumulation_steps >= 1 + self.num_steps_epoch = int(len(self.dataloader) / params.batch_size) + 1 + num_train_optimization_steps = int(self.num_steps_epoch / params.gradient_accumulation_steps * params.n_epoch) + 1 + warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop) + + no_decay = ['bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in student.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': params.weight_decay}, + {'params': [p for n, p in student.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 'weight_decay': 0.0} + ] + logger.info("------ Number of trainable parameters (student): %i" % sum([p.numel() for p in self.student.parameters() if p.requires_grad])) + logger.info("------ Number of parameters (student): %i" % sum([p.numel() for p in self.student.parameters()])) + self.optimizer = AdamW(optimizer_grouped_parameters, + lr=params.learning_rate, + eps=params.adam_epsilon, + betas=(0.9, 0.98)) + self.scheduler = WarmupLinearSchedule(self.optimizer, + warmup_steps=warmup_steps, + t_total=num_train_optimization_steps) + + if self.fp16: + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + logger.info(f"Using fp16 training: {self.params.fp16_opt_level} level") + self.student, self.optimizer = amp.initialize(self.student, + self.optimizer, + opt_level=self.params.fp16_opt_level) + self.teacher = self.teacher.half() + + if self.multi_gpu: + if self.fp16: + from apex.parallel import DistributedDataParallel + logger.info("Using apex.parallel.DistributedDataParallel for distributed training.") + self.student = DistributedDataParallel(self.student) + else: + from torch.nn.parallel import DistributedDataParallel + logger.info("Using nn.parallel.DistributedDataParallel for distributed training.") + self.student = DistributedDataParallel(self.student, + device_ids=[params.local_rank], + output_device=params.local_rank) + + self.is_master = params.is_master + if self.is_master: + logger.info('--- Initializing Tensorboard') + self.tensorboard = SummaryWriter(log_dir=os.path.join(self.dump_path, 'log', 'train')) + self.tensorboard.add_text(tag='config', text_string=str(self.params), global_step=0) + + def get_iterator(self, + seed: int = None): + """ + Initialize the data iterator. + Each process has its own data iterator (iterating on his own random portion of the dataset). + + Input: + ------ + seed: `int` - The random seed. + """ + logger.info('--- Initializing Data Iterator') + self.data_iterator = self.dataloader.get_iterator(seed=seed) + + def get_batch(self): + """ + Call the data iterator to output a new batch. + If the data iterator went through the whole dataset, create a new iterator. + """ + assert hasattr(self, 'data_iterator') + try: + x = next(self.data_iterator) + except StopIteration: + logger.warning('--- Went through the whole dataset. Creating new data iterator.') + self.data_iterator = self.dataloader.get_iterator() + x = next(self.data_iterator) + return x + + def prepare_batch(self, + batch): + """ + Prepare the batch: from the token_ids and the lenghts, compute the attention mask and the masked label for MLM. + + Input: + ------ + batch: `Tuple` + token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded. + lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch. + + Output: + ------- + token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM. + attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention. + mlm_labels: `torch.tensor(bs, seq_length)` - The masked languge modeling labels. There is a -1 where there is nothing to predict. + """ + token_ids, lengths = batch + token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths) + assert token_ids.size(0) == lengths.size(0) + + attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None]) + + bs, max_seq_len = token_ids.size() + mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids) + + x_prob = self.token_probs[token_ids.flatten()] + n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item()) + tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False) + pred_mask = torch.zeros(bs * max_seq_len, dtype=torch.uint8, device=token_ids.device) + pred_mask[tgt_ids] = 1 + pred_mask = pred_mask.view(bs, max_seq_len) + + pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0 + + # mask a number of words == 0 [8] (faster with fp16) + if self.fp16: + n1 = pred_mask.sum().item() + if n1 > 8: + pred_mask = pred_mask.view(-1) + n2 = max(n1 % 8, 8 * (n1 // 8)) + if n2 != n1: + pred_mask[torch.nonzero(pred_mask).view(-1)[:n1-n2]] = 0 + pred_mask = pred_mask.view(bs, max_seq_len) + assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item() + + _token_ids_real = token_ids[pred_mask] + _token_ids_rand = _token_ids_real.clone().random_(self.params.vocab_size) + _token_ids_mask = _token_ids_real.clone().fill_(self.params.special_tok_ids['mask_token']) + probs = torch.multinomial(self.pred_probs, len(_token_ids_real), replacement=True) + _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long() + token_ids = token_ids.masked_scatter(pred_mask, _token_ids) + + mlm_labels[1-pred_mask] = -1 + + return token_ids, attn_mask, mlm_labels + + def round_batch(self, + x: torch.tensor, + lengths: torch.tensor): + """ + For float16 only. + Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8. + + Input: + ------ + x: `torch.tensor(bs, seq_length)` - The token ids. + lengths: `torch.tensor(bs, seq_length)` - The lengths of each of the sequence in the batch. + + Output: + ------- + x: `torch.tensor(new_bs, new_seq_length)` - The updated token ids. + lengths: `torch.tensor(new_bs, new_seq_length)` - The updated lengths. + """ + if not self.fp16 or len(lengths) < 8: + return x, lengths + + # number of sentences == 0 [8] + bs1 = len(lengths) + bs2 = 8 * (bs1 // 8) + assert bs2 > 0 and bs2 % 8 == 0 + if bs1 != bs2: + idx = torch.randperm(bs1)[:bs2] + lengths = lengths[idx] + slen = lengths.max().item() + x = x[idx, :slen] + else: + idx = None + + # sequence length == 0 [8] + ml1 = x.size(1) + if ml1 % 8 != 0: + pad = 8 - (ml1 % 8) + ml2 = ml1 + pad + pad_id = self.params.special_tok_ids['pad_token'] + padding_tensor = torch.zeros(bs2, pad, dtype=torch.long, device=x.device).fill_(pad_id) + x = torch.cat([x, padding_tensor], 1) + assert x.size() == (bs2, ml2) + + assert x.size(0) % 8 == 0 + assert x.size(1) % 8 == 0 + return x, lengths + + def train(self): + """ + The real training loop. + """ + if self.is_master: logger.info('Starting training') + self.student.train() + self.teacher.eval() + + for _ in range(self.params.n_epoch): + if self.is_master: logger.info(f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}') + + iter_bar = trange(self.num_steps_epoch, desc="-Iter", disable=self.params.local_rank not in [-1, 0]) + for __ in range(self.num_steps_epoch): + batch = self.get_batch() + if self.params.n_gpu > 0: + batch = tuple(t.to(f'cuda:{self.params.local_rank}') for t in batch) + token_ids, attn_mask, mlm_labels = self.prepare_batch(batch=batch) + + self.step(input_ids=token_ids, attention_mask=attn_mask, mlm_labels=mlm_labels) + + iter_bar.update() + iter_bar.set_postfix({'Last_loss': f'{self.last_loss:.2f}', + 'Avg_cum_loss': f'{self.total_loss_epoch/self.n_iter:.2f}'}) + iter_bar.close() + + if self.is_master: logger.info(f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}') + self.end_epoch() + + if self.is_master: logger.info('Training is finished') + + def step(self, + input_ids: torch.tensor, + attention_mask: torch.tensor, + mlm_labels: torch.tensor): + """ + One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation), + and possibly a parameter update (depending on the gradient accumulation). + + Input: + ------ + input_ids: `torch.tensor(bs, seq_length)` - The token ids. + attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention. + mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. + """ + s_logits = self.student(input_ids=input_ids, attention_mask=attention_mask)[0] # (bs, seq_length, voc_size) + with torch.no_grad(): + t_logits = self.teacher(input_ids=input_ids, attention_mask=attention_mask)[0] # (bs, seq_length, voc_size) + assert s_logits.size() == t_logits.size() + + #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 + #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2 + if self.params.restrict_ce_to_mask: + mask = (mlm_labels>-1).unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size) + else: + mask = attention_mask.unsqueeze(-1).expand_as(s_logits) # (bs, seq_lenth, voc_size) + s_logits_slct = torch.masked_select(s_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask + s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask + t_logits_slct = torch.masked_select(t_logits, mask) # (bs * seq_length * voc_size) modulo the 1s in mask + t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1)) # (bs * seq_length, voc_size) modulo the 1s in mask + assert t_logits_slct.size() == s_logits_slct.size() + + loss_ce = self.ce_loss_fct(F.log_softmax(s_logits_slct/self.temperature, dim=-1), + F.softmax(t_logits_slct/self.temperature, dim=-1)) * (self.temperature)**2 + loss = self.alpha_ce*loss_ce + if self.alpha_mlm > 0.: + loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1)) + loss += self.alpha_mlm * loss_mlm + if self.alpha_mse > 0.: + loss_mse = self.mse_loss_fct(s_logits_slct, t_logits_slct)/s_logits_slct.size(0) # Reproducing batchmean reduction + loss += self.alpha_mse * loss_mse + + self.total_loss_epoch += loss.item() + self.last_loss = loss.item() + self.last_loss_ce = loss_ce.item() + if self.alpha_mlm > 0.: + self.last_loss_mlm = loss_mlm.item() + if self.alpha_mse > 0.: + self.last_loss_mse = loss_mse.item() + + self.optimize(loss) + + self.n_sequences_epoch += input_ids.size(0) + + def optimize(self, + loss): + """ + Normalization on the loss (gradient accumulation or distributed training), followed by + backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation). + Also update the metrics for tensorboard. + """ + # Check for NaN + if (loss != loss).data.any(): + logger.error('NaN detected') + exit() + + if self.multi_gpu: + loss = loss.mean() + if self.params.gradient_accumulation_steps > 1: + loss = loss / self.params.gradient_accumulation_steps + + if self.fp16: + from apex import amp + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + self.iter() + if self.n_iter % self.params.gradient_accumulation_steps == 0: + if self.fp16: + torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.params.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(self.student.parameters(), self.params.max_grad_norm) + self.scheduler.step() + self.optimizer.step() + self.optimizer.zero_grad() + + def iter(self): + """ + Update global counts, write to tensorboard and save checkpoint. + """ + self.n_iter += 1 + self.n_total_iter += 1 + + if self.n_total_iter % self.params.log_interval == 0: + self.log_tensorboard() + if self.n_total_iter % self.params.checkpoint_interval == 0: + self.save_checkpoint() + + def log_tensorboard(self): + """ + Log into tensorboard. Only by the master process. + """ + if not self.is_master: + return + + for param_name, param in self.student.named_parameters(): + self.tensorboard.add_scalar(tag='parameter_mean/' + param_name, scalar_value=param.data.mean(), global_step=self.n_total_iter) + self.tensorboard.add_scalar(tag='parameter_std/' + param_name, scalar_value=param.data.std(), global_step=self.n_total_iter) + if param.grad is None: + continue + self.tensorboard.add_scalar(tag="grad_mean/" + param_name, scalar_value=param.grad.data.mean(),global_step=self.n_total_iter) + self.tensorboard.add_scalar(tag="grad_std/" + param_name, scalar_value=param.grad.data.std(), global_step=self.n_total_iter) + + self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch", scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.n_total_iter) + self.tensorboard.add_scalar(tag="losses/loss", scalar_value=self.last_loss, global_step=self.n_total_iter) + self.tensorboard.add_scalar(tag="losses/loss_ce", scalar_value=self.last_loss_ce, global_step=self.n_total_iter) + if self.alpha_mlm > 0.: + self.tensorboard.add_scalar(tag="losses/loss_mlm", scalar_value=self.last_loss_mlm, global_step=self.n_total_iter) + if self.alpha_mse > 0.: + self.tensorboard.add_scalar(tag="losses/loss_mse", scalar_value=self.last_loss_mse, global_step=self.n_total_iter) + self.tensorboard.add_scalar(tag="learning_rate/lr", scalar_value=self.scheduler.get_lr()[0], global_step=self.n_total_iter) + + def end_epoch(self): + """ + Finally arrived at the end of epoch (full pass on dataset). + Do some tensorboard logging and checkpoint saving. + """ + logger.info(f'{self.n_sequences_epoch} sequences have been trained during this epoch.') + + if self.is_master: + self.save_checkpoint(checkpoint_name=f'model_epoch_{self.epoch}.pth') + self.tensorboard.add_scalar(tag='epoch/loss', scalar_value=self.total_loss_epoch/self.n_iter, global_step=self.epoch) + + self.epoch += 1 + self.n_sequences_epoch = 0 + self.n_iter = 0 + self.total_loss_epoch = 0 + + def save_checkpoint(self, + checkpoint_name: str = 'checkpoint.pth'): + """ + Save the current state. Only by the master process. + """ + if not self.is_master: + return + mdl_to_save = self.student.module if hasattr(self.student, 'module') else self.student + mdl_to_save.config.save_pretrained(self.dump_path) + state_dict = mdl_to_save.state_dict() + torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name)) diff --git a/examples/distillation/utils.py b/examples/distillation/utils.py new file mode 100644 index 0000000000..b3a9f15891 --- /dev/null +++ b/examples/distillation/utils.py @@ -0,0 +1,112 @@ +import git +import json +import os +import socket +import torch +import numpy as np + +import logging +logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.INFO) +logger = logging.getLogger(__name__) + + +def git_log(folder_path: str): + """ + Log commit info. + """ + repo = git.Repo(search_parent_directories=True) + repo_infos = { + 'repo_id': str(repo), + 'repo_sha': str(repo.head.object.hexsha), + 'repo_branch': str(repo.active_branch) + } + + with open(os.path.join(folder_path, 'git_log.json'), 'w') as f: + json.dump(repo_infos, f, indent=4) + + +def init_gpu_params(params): + """ + Handle single and multi-GPU / multi-node. + """ + if params.n_gpu <= 0: + params.local_rank = 0 + params.master_port = -1 + params.is_master = True + params.multi_gpu = False + return + + assert torch.cuda.is_available() + + logger.info('Initializing GPUs') + if params.n_gpu > 1: + assert params.local_rank != -1 + + params.world_size = int(os.environ['WORLD_SIZE']) + params.n_gpu_per_node = int(os.environ['N_GPU_NODE']) + params.global_rank = int(os.environ['RANK']) + + # number of nodes / node ID + params.n_nodes = params.world_size // params.n_gpu_per_node + params.node_id = params.global_rank // params.n_gpu_per_node + params.multi_gpu = True + + assert params.n_nodes == int(os.environ['N_NODES']) + assert params.node_id == int(os.environ['NODE_RANK']) + + # local job (single GPU) + else: + assert params.local_rank == -1 + + params.n_nodes = 1 + params.node_id = 0 + params.local_rank = 0 + params.global_rank = 0 + params.world_size = 1 + params.n_gpu_per_node = 1 + params.multi_gpu = False + + # sanity checks + assert params.n_nodes >= 1 + assert 0 <= params.node_id < params.n_nodes + assert 0 <= params.local_rank <= params.global_rank < params.world_size + assert params.world_size == params.n_nodes * params.n_gpu_per_node + + # define whether this is the master process / if we are in multi-node distributed mode + params.is_master = params.node_id == 0 and params.local_rank == 0 + params.multi_node = params.n_nodes > 1 + + # summary + PREFIX = f"--- Global rank: {params.global_rank} - " + logger.info(PREFIX + "Number of nodes: %i" % params.n_nodes) + logger.info(PREFIX + "Node ID : %i" % params.node_id) + logger.info(PREFIX + "Local rank : %i" % params.local_rank) + logger.info(PREFIX + "World size : %i" % params.world_size) + logger.info(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node) + logger.info(PREFIX + "Master : %s" % str(params.is_master)) + logger.info(PREFIX + "Multi-node : %s" % str(params.multi_node)) + logger.info(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu)) + logger.info(PREFIX + "Hostname : %s" % socket.gethostname()) + + # set GPU device + torch.cuda.set_device(params.local_rank) + + # initialize multi-GPU + if params.multi_gpu: + logger.info("Initializing PyTorch distributed") + torch.distributed.init_process_group( + init_method='env://', + backend='nccl', + ) + + +def set_seed(args): + """ + Set the random seed. + """ + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed)