Fix TensorFlow dataset generator (#4881)
* fix TensorFlow generator * Better features handling * Apply style * Apply style * Fix squad as well * Apply style * Better factorization of TF Tensors creation
This commit is contained in:
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from dataclasses import asdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
@@ -81,26 +82,16 @@ if is_tf_available():
|
|||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for ex in features:
|
for ex in features:
|
||||||
yield (
|
d = {k: v for k, v in asdict(ex).items() if v is not None}
|
||||||
{
|
label = d.pop("label")
|
||||||
"input_ids": ex.input_ids,
|
yield (d, label)
|
||||||
"attention_mask": ex.attention_mask,
|
|
||||||
"token_type_ids": ex.token_type_ids,
|
input_names = ["input_ids"] + tokenizer.model_input_names
|
||||||
},
|
|
||||||
ex.label,
|
|
||||||
)
|
|
||||||
|
|
||||||
return tf.data.Dataset.from_generator(
|
return tf.data.Dataset.from_generator(
|
||||||
gen,
|
gen,
|
||||||
({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
|
({k: tf.int32 for k in input_names}, tf.int64),
|
||||||
(
|
({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
|
||||||
{
|
|
||||||
"input_ids": tf.TensorShape([None]),
|
|
||||||
"attention_mask": tf.TensorShape([None]),
|
|
||||||
"token_type_ids": tf.TensorShape([None]),
|
|
||||||
},
|
|
||||||
tf.TensorShape([]),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -389,6 +389,23 @@ def squad_convert_examples_to_features(
|
|||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for i, ex in enumerate(features):
|
for i, ex in enumerate(features):
|
||||||
|
if ex.token_type_ids is None:
|
||||||
|
yield (
|
||||||
|
{
|
||||||
|
"input_ids": ex.input_ids,
|
||||||
|
"attention_mask": ex.attention_mask,
|
||||||
|
"feature_index": i,
|
||||||
|
"qas_id": ex.qas_id,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"start_positions": ex.start_position,
|
||||||
|
"end_positions": ex.end_position,
|
||||||
|
"cls_index": ex.cls_index,
|
||||||
|
"p_mask": ex.p_mask,
|
||||||
|
"is_impossible": ex.is_impossible,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
yield (
|
yield (
|
||||||
{
|
{
|
||||||
"input_ids": ex.input_ids,
|
"input_ids": ex.input_ids,
|
||||||
@@ -407,6 +424,7 @@ def squad_convert_examples_to_features(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Why have we split the batch into a tuple? PyTorch just has a list of tensors.
|
# Why have we split the batch into a tuple? PyTorch just has a list of tensors.
|
||||||
|
if "token_type_ids" in tokenizer.model_input_names:
|
||||||
train_types = (
|
train_types = (
|
||||||
{
|
{
|
||||||
"input_ids": tf.int32,
|
"input_ids": tf.int32,
|
||||||
@@ -440,6 +458,33 @@ def squad_convert_examples_to_features(
|
|||||||
"is_impossible": tf.TensorShape([]),
|
"is_impossible": tf.TensorShape([]),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
train_types = (
|
||||||
|
{"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
|
||||||
|
{
|
||||||
|
"start_positions": tf.int64,
|
||||||
|
"end_positions": tf.int64,
|
||||||
|
"cls_index": tf.int64,
|
||||||
|
"p_mask": tf.int32,
|
||||||
|
"is_impossible": tf.int32,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
train_shapes = (
|
||||||
|
{
|
||||||
|
"input_ids": tf.TensorShape([None]),
|
||||||
|
"attention_mask": tf.TensorShape([None]),
|
||||||
|
"feature_index": tf.TensorShape([]),
|
||||||
|
"qas_id": tf.TensorShape([]),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"start_positions": tf.TensorShape([]),
|
||||||
|
"end_positions": tf.TensorShape([]),
|
||||||
|
"cls_index": tf.TensorShape([]),
|
||||||
|
"p_mask": tf.TensorShape([None]),
|
||||||
|
"is_impossible": tf.TensorShape([]),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
|
return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user