Fixed training for TF XLM
This commit is contained in:
@@ -84,7 +84,8 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
|
|||||||
attn_mask = mask
|
attn_mask = mask
|
||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
assert shape_list(mask) == [bs, slen]
|
# assert shape_list(mask) == [bs, slen]
|
||||||
|
tf.debugging.assert_equal(shape_list(mask), [bs, slen])
|
||||||
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
|
||||||
|
|
||||||
mask = tf.cast(mask, dtype=dtype)
|
mask = tf.cast(mask, dtype=dtype)
|
||||||
@@ -318,7 +319,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# check inputs
|
# check inputs
|
||||||
bs, slen = shape_list(input_ids)
|
bs, slen = shape_list(input_ids)
|
||||||
assert shape_list(lengths)[0] == bs
|
# assert shape_list(lengths)[0] == bs
|
||||||
|
tf.debugging.assert_equal(shape_list(lengths)[0], bs)
|
||||||
# assert lengths.max().item() <= slen
|
# assert lengths.max().item() <= slen
|
||||||
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
# input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
|
||||||
# assert (src_enc is None) == (src_len is None)
|
# assert (src_enc is None) == (src_len is None)
|
||||||
@@ -335,12 +337,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
position_ids = tf.expand_dims(tf.range(slen), axis=0)
|
||||||
else:
|
else:
|
||||||
assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
|
||||||
|
tf.debugging.assert_equal(shape_list(position_ids), [bs, slen])
|
||||||
# position_ids = position_ids.transpose(0, 1)
|
# position_ids = position_ids.transpose(0, 1)
|
||||||
|
|
||||||
# langs
|
# langs
|
||||||
if langs is not None:
|
if langs is not None:
|
||||||
assert shape_list(langs) == [bs, slen] # (slen, bs)
|
# assert shape_list(langs) == [bs, slen] # (slen, bs)
|
||||||
|
tf.debugging.assert_equal(shape_list(langs), [bs, slen])
|
||||||
# langs = langs.transpose(0, 1)
|
# langs = langs.transpose(0, 1)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
|
|||||||
Reference in New Issue
Block a user