@@ -418,13 +418,14 @@ class FlaxDataCollatorForT5MLM:
|
|||||||
orig_length = length
|
orig_length = length
|
||||||
|
|
||||||
num_noise_tokens = int(np.round(length * self.noise_density))
|
num_noise_tokens = int(np.round(length * self.noise_density))
|
||||||
|
num_nonnoise_tokens = length - num_noise_tokens
|
||||||
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
|
# avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
|
||||||
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
|
num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
|
||||||
num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))
|
# num_noise_spans must not exceed either num_noise_tokens or num_nonnoise_tokens
|
||||||
|
num_noise_spans = int(np.round(min(num_noise_tokens, num_nonnoise_tokens) / self.mean_noise_span_length))
|
||||||
|
|
||||||
# avoid degeneracy by ensuring positive number of noise spans
|
# avoid degeneracy by ensuring positive number of noise spans
|
||||||
num_noise_spans = max(num_noise_spans, 1)
|
num_noise_spans = max(num_noise_spans, 1)
|
||||||
num_nonnoise_tokens = length - num_noise_tokens
|
|
||||||
|
|
||||||
# pick the lengths of the noise spans and the non-noise spans
|
# pick the lengths of the noise spans and the non-noise spans
|
||||||
def _random_segmentation(num_items, num_segments):
|
def _random_segmentation(num_items, num_segments):
|
||||||
|
|||||||
Reference in New Issue
Block a user