Fix TensorFlow dataset generator (#4881)

* fix TensorFlow generator

* Better features handling

* Apply style

* Apply style

* Fix squad as well

* Apply style

* Better factorization of TF Tensors creation
This commit is contained in:
Julien Plu
2020-07-01 01:49:11 +02:00
committed by GitHub
parent 501040fd30
commit fcf0652460
2 changed files with 101 additions and 65 deletions

View File

@@ -389,57 +389,102 @@ def squad_convert_examples_to_features(
def gen():
for i, ex in enumerate(features):
yield (
{
"input_ids": ex.input_ids,
"attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids,
"feature_index": i,
"qas_id": ex.qas_id,
},
{
"start_positions": ex.start_position,
"end_positions": ex.end_position,
"cls_index": ex.cls_index,
"p_mask": ex.p_mask,
"is_impossible": ex.is_impossible,
},
)
if ex.token_type_ids is None:
yield (
{
"input_ids": ex.input_ids,
"attention_mask": ex.attention_mask,
"feature_index": i,
"qas_id": ex.qas_id,
},
{
"start_positions": ex.start_position,
"end_positions": ex.end_position,
"cls_index": ex.cls_index,
"p_mask": ex.p_mask,
"is_impossible": ex.is_impossible,
},
)
else:
yield (
{
"input_ids": ex.input_ids,
"attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids,
"feature_index": i,
"qas_id": ex.qas_id,
},
{
"start_positions": ex.start_position,
"end_positions": ex.end_position,
"cls_index": ex.cls_index,
"p_mask": ex.p_mask,
"is_impossible": ex.is_impossible,
},
)
# Why have we split the batch into a tuple? PyTorch just has a list of tensors.
train_types = (
{
"input_ids": tf.int32,
"attention_mask": tf.int32,
"token_type_ids": tf.int32,
"feature_index": tf.int64,
"qas_id": tf.string,
},
{
"start_positions": tf.int64,
"end_positions": tf.int64,
"cls_index": tf.int64,
"p_mask": tf.int32,
"is_impossible": tf.int32,
},
)
if "token_type_ids" in tokenizer.model_input_names:
train_types = (
{
"input_ids": tf.int32,
"attention_mask": tf.int32,
"token_type_ids": tf.int32,
"feature_index": tf.int64,
"qas_id": tf.string,
},
{
"start_positions": tf.int64,
"end_positions": tf.int64,
"cls_index": tf.int64,
"p_mask": tf.int32,
"is_impossible": tf.int32,
},
)
train_shapes = (
{
"input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]),
"token_type_ids": tf.TensorShape([None]),
"feature_index": tf.TensorShape([]),
"qas_id": tf.TensorShape([]),
},
{
"start_positions": tf.TensorShape([]),
"end_positions": tf.TensorShape([]),
"cls_index": tf.TensorShape([]),
"p_mask": tf.TensorShape([None]),
"is_impossible": tf.TensorShape([]),
},
)
train_shapes = (
{
"input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]),
"token_type_ids": tf.TensorShape([None]),
"feature_index": tf.TensorShape([]),
"qas_id": tf.TensorShape([]),
},
{
"start_positions": tf.TensorShape([]),
"end_positions": tf.TensorShape([]),
"cls_index": tf.TensorShape([]),
"p_mask": tf.TensorShape([None]),
"is_impossible": tf.TensorShape([]),
},
)
else:
train_types = (
{"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
{
"start_positions": tf.int64,
"end_positions": tf.int64,
"cls_index": tf.int64,
"p_mask": tf.int32,
"is_impossible": tf.int32,
},
)
train_shapes = (
{
"input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]),
"feature_index": tf.TensorShape([]),
"qas_id": tf.TensorShape([]),
},
{
"start_positions": tf.TensorShape([]),
"end_positions": tf.TensorShape([]),
"cls_index": tf.TensorShape([]),
"p_mask": tf.TensorShape([None]),
"is_impossible": tf.TensorShape([]),
},
)
return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
else: