cleaning - change ' to " (black requirements)
This commit is contained in:
@@ -114,17 +114,17 @@ class LmSeqsDataset(Dataset):
|
|||||||
"""
|
"""
|
||||||
Remove sequences with a (too) high level of unknown tokens.
|
Remove sequences with a (too) high level of unknown tokens.
|
||||||
"""
|
"""
|
||||||
if 'unk_token' not in self.params.special_tok_ids:
|
if "unk_token" not in self.params.special_tok_ids:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
unk_token_id = self.params.special_tok_ids['unk_token']
|
unk_token_id = self.params.special_tok_ids["unk_token"]
|
||||||
init_size = len(self)
|
init_size = len(self)
|
||||||
unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
|
unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
|
||||||
indices = (unk_occs / self.lengths) < 0.5
|
indices = (unk_occs / self.lengths) < 0.5
|
||||||
self.token_ids = self.token_ids[indices]
|
self.token_ids = self.token_ids[indices]
|
||||||
self.lengths = self.lengths[indices]
|
self.lengths = self.lengths[indices]
|
||||||
new_size = len(self)
|
new_size = len(self)
|
||||||
logger.info(f'Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).')
|
logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).")
|
||||||
|
|
||||||
def print_statistics(self):
|
def print_statistics(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user