Fixed label datatype for STS-B (#6492)
* fixed label datatype for sts-b * naming update * make style * make style
This commit is contained in:
@@ -79,6 +79,7 @@ if is_tf_available():
|
||||
processor = glue_processors[task]()
|
||||
examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
|
||||
features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
|
||||
label_type = tf.float32 if task == "sts-b" else tf.int64
|
||||
|
||||
def gen():
|
||||
for ex in features:
|
||||
@@ -90,7 +91,7 @@ if is_tf_available():
|
||||
|
||||
return tf.data.Dataset.from_generator(
|
||||
gen,
|
||||
({k: tf.int32 for k in input_names}, tf.int64),
|
||||
({k: tf.int32 for k in input_names}, label_type),
|
||||
({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user