diff --git a/examples/distillation/lm_seqs_dataset.py b/examples/distillation/lm_seqs_dataset.py index a29e9efb28..8f444f4e0e 100644 --- a/examples/distillation/lm_seqs_dataset.py +++ b/examples/distillation/lm_seqs_dataset.py @@ -114,17 +114,17 @@ class LmSeqsDataset(Dataset): """ Remove sequences with a (too) high level of unknown tokens. """ - if 'unk_token' not in self.params.special_tok_ids: + if "unk_token" not in self.params.special_tok_ids: return else: - unk_token_id = self.params.special_tok_ids['unk_token'] + unk_token_id = self.params.special_tok_ids["unk_token"] init_size = len(self) unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids]) - indices = (unk_occs/self.lengths) < 0.5 + indices = (unk_occs / self.lengths) < 0.5 self.token_ids = self.token_ids[indices] self.lengths = self.lengths[indices] new_size = len(self) - logger.info(f'Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).') + logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).") def print_statistics(self): """