Fix TensorFlow dataset generator (#4881)
* Fix TensorFlow generator
* Better features handling
* Apply style
* Apply style
* Fix squad as well
* Apply style
* Better factorization of TF Tensors creation
This commit is contained in:
@@ -389,57 +389,102 @@ def squad_convert_examples_to_features(
|
||||
|
||||
def gen():
|
||||
for i, ex in enumerate(features):
|
||||
yield (
|
||||
{
|
||||
"input_ids": ex.input_ids,
|
||||
"attention_mask": ex.attention_mask,
|
||||
"token_type_ids": ex.token_type_ids,
|
||||
"feature_index": i,
|
||||
"qas_id": ex.qas_id,
|
||||
},
|
||||
{
|
||||
"start_positions": ex.start_position,
|
||||
"end_positions": ex.end_position,
|
||||
"cls_index": ex.cls_index,
|
||||
"p_mask": ex.p_mask,
|
||||
"is_impossible": ex.is_impossible,
|
||||
},
|
||||
)
|
||||
if ex.token_type_ids is None:
|
||||
yield (
|
||||
{
|
||||
"input_ids": ex.input_ids,
|
||||
"attention_mask": ex.attention_mask,
|
||||
"feature_index": i,
|
||||
"qas_id": ex.qas_id,
|
||||
},
|
||||
{
|
||||
"start_positions": ex.start_position,
|
||||
"end_positions": ex.end_position,
|
||||
"cls_index": ex.cls_index,
|
||||
"p_mask": ex.p_mask,
|
||||
"is_impossible": ex.is_impossible,
|
||||
},
|
||||
)
|
||||
else:
|
||||
yield (
|
||||
{
|
||||
"input_ids": ex.input_ids,
|
||||
"attention_mask": ex.attention_mask,
|
||||
"token_type_ids": ex.token_type_ids,
|
||||
"feature_index": i,
|
||||
"qas_id": ex.qas_id,
|
||||
},
|
||||
{
|
||||
"start_positions": ex.start_position,
|
||||
"end_positions": ex.end_position,
|
||||
"cls_index": ex.cls_index,
|
||||
"p_mask": ex.p_mask,
|
||||
"is_impossible": ex.is_impossible,
|
||||
},
|
||||
)
|
||||
|
||||
# Why have we split the batch into a tuple? PyTorch just has a list of tensors.
|
||||
train_types = (
|
||||
{
|
||||
"input_ids": tf.int32,
|
||||
"attention_mask": tf.int32,
|
||||
"token_type_ids": tf.int32,
|
||||
"feature_index": tf.int64,
|
||||
"qas_id": tf.string,
|
||||
},
|
||||
{
|
||||
"start_positions": tf.int64,
|
||||
"end_positions": tf.int64,
|
||||
"cls_index": tf.int64,
|
||||
"p_mask": tf.int32,
|
||||
"is_impossible": tf.int32,
|
||||
},
|
||||
)
|
||||
if "token_type_ids" in tokenizer.model_input_names:
|
||||
train_types = (
|
||||
{
|
||||
"input_ids": tf.int32,
|
||||
"attention_mask": tf.int32,
|
||||
"token_type_ids": tf.int32,
|
||||
"feature_index": tf.int64,
|
||||
"qas_id": tf.string,
|
||||
},
|
||||
{
|
||||
"start_positions": tf.int64,
|
||||
"end_positions": tf.int64,
|
||||
"cls_index": tf.int64,
|
||||
"p_mask": tf.int32,
|
||||
"is_impossible": tf.int32,
|
||||
},
|
||||
)
|
||||
|
||||
train_shapes = (
|
||||
{
|
||||
"input_ids": tf.TensorShape([None]),
|
||||
"attention_mask": tf.TensorShape([None]),
|
||||
"token_type_ids": tf.TensorShape([None]),
|
||||
"feature_index": tf.TensorShape([]),
|
||||
"qas_id": tf.TensorShape([]),
|
||||
},
|
||||
{
|
||||
"start_positions": tf.TensorShape([]),
|
||||
"end_positions": tf.TensorShape([]),
|
||||
"cls_index": tf.TensorShape([]),
|
||||
"p_mask": tf.TensorShape([None]),
|
||||
"is_impossible": tf.TensorShape([]),
|
||||
},
|
||||
)
|
||||
train_shapes = (
|
||||
{
|
||||
"input_ids": tf.TensorShape([None]),
|
||||
"attention_mask": tf.TensorShape([None]),
|
||||
"token_type_ids": tf.TensorShape([None]),
|
||||
"feature_index": tf.TensorShape([]),
|
||||
"qas_id": tf.TensorShape([]),
|
||||
},
|
||||
{
|
||||
"start_positions": tf.TensorShape([]),
|
||||
"end_positions": tf.TensorShape([]),
|
||||
"cls_index": tf.TensorShape([]),
|
||||
"p_mask": tf.TensorShape([None]),
|
||||
"is_impossible": tf.TensorShape([]),
|
||||
},
|
||||
)
|
||||
else:
|
||||
train_types = (
|
||||
{"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
|
||||
{
|
||||
"start_positions": tf.int64,
|
||||
"end_positions": tf.int64,
|
||||
"cls_index": tf.int64,
|
||||
"p_mask": tf.int32,
|
||||
"is_impossible": tf.int32,
|
||||
},
|
||||
)
|
||||
|
||||
train_shapes = (
|
||||
{
|
||||
"input_ids": tf.TensorShape([None]),
|
||||
"attention_mask": tf.TensorShape([None]),
|
||||
"feature_index": tf.TensorShape([]),
|
||||
"qas_id": tf.TensorShape([]),
|
||||
},
|
||||
{
|
||||
"start_positions": tf.TensorShape([]),
|
||||
"end_positions": tf.TensorShape([]),
|
||||
"cls_index": tf.TensorShape([]),
|
||||
"p_mask": tf.TensorShape([None]),
|
||||
"is_impossible": tf.TensorShape([]),
|
||||
},
|
||||
)
|
||||
|
||||
return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user