From e83d9f1c1d29890dd470de74f41627630e52abdc Mon Sep 17 00:00:00 2001
From: VictorSanh
Date: Fri, 10 Jan 2020 19:34:25 -0500
Subject: [PATCH] cleaning - change ' to " (black requirements)

---
 examples/distillation/lm_seqs_dataset.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/examples/distillation/lm_seqs_dataset.py b/examples/distillation/lm_seqs_dataset.py
index a29e9efb28..8f444f4e0e 100644
--- a/examples/distillation/lm_seqs_dataset.py
+++ b/examples/distillation/lm_seqs_dataset.py
@@ -114,17 +114,17 @@ class LmSeqsDataset(Dataset):
         """
         Remove sequences with a (too) high level of unknown tokens.
         """
-        if 'unk_token' not in self.params.special_tok_ids:
+        if "unk_token" not in self.params.special_tok_ids:
             return
         else:
-            unk_token_id = self.params.special_tok_ids['unk_token']
+            unk_token_id = self.params.special_tok_ids["unk_token"]
         init_size = len(self)
         unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
-        indices = (unk_occs/self.lengths) < 0.5
+        indices = (unk_occs / self.lengths) < 0.5
         self.token_ids = self.token_ids[indices]
         self.lengths = self.lengths[indices]
         new_size = len(self)
-        logger.info(f'Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).')
+        logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).")
 
     def print_statistics(self):
         """