Fixed label datatype for STS-B (#6492)
* fixed label datatype for sts-b * naming update * make style * make style
This commit is contained in:
@@ -79,6 +79,7 @@ if is_tf_available():
|
|||||||
processor = glue_processors[task]()
|
processor = glue_processors[task]()
|
||||||
examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
|
examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
|
||||||
features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
|
features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
|
||||||
|
label_type = tf.float32 if task == "sts-b" else tf.int64
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
for ex in features:
|
for ex in features:
|
||||||
@@ -90,7 +91,7 @@ if is_tf_available():
|
|||||||
|
|
||||||
return tf.data.Dataset.from_generator(
|
return tf.data.Dataset.from_generator(
|
||||||
gen,
|
gen,
|
||||||
({k: tf.int32 for k in input_names}, tf.int64),
|
({k: tf.int32 for k in input_names}, label_type),
|
||||||
({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
|
({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user