Last active
March 31, 2023 17:05
-
-
Save shijianjian/93b52ae96137fcd3d63edd92eb3f5046 to your computer and use it in GitHub Desktop.
Convert hdf5 file to TFRecords.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
This converter is used to convert an HDF5 file to TFRecords.
Dataset used:
    This is designed for the point cloud HDF5 data of PointNet,
    which can be downloaded from https://github.com/charlesq34/pointnet/sem_seg
    The sample data shape is
    h5py {
        'data': (1000, 4096, 9),  # (number_of_data, points, channels)
        'label': (1000, 4096)     # (number_of_data, label_of_points)
    }
If you want to modify this for your own HDF5 data,
the only thing you need to modify is the "get_feature(point_cloud, label)" function.
'''
import h5py | |
import tensorflow as tf | |
# TFRecords features can only store flat lists / 1-D arrays.
# If you have a multi-dimensional array, flatten it first with:
# array = array.reshape(-1)
def _int64_feature(value):
    """Wrap a sequence of integers in a ``tf.train.Feature``.

    Args:
        value: Flat (1-D) iterable of integers. It is forwarded unchanged to
            ``tf.train.Int64List``, so multi-dimensional arrays have to be
            flattened by the caller beforehand.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``value``.
    """
    int_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int_list)
def _floats_feature(value):
    """Wrap a sequence of floats in a ``tf.train.Feature``.

    Args:
        value: Flat (1-D) iterable of floats. It is forwarded unchanged to
            ``tf.train.FloatList``, so multi-dimensional arrays have to be
            flattened by the caller beforehand.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds ``value``.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
def get_all_keys_from_h5(h5_file):
    """Return the top-level keys of ``h5_file`` as a list.

    Args:
        h5_file: Any mapping-like object exposing ``keys()`` (for example an
            open ``h5py.File``).

    Returns:
        A new ``list`` with the keys in the order ``h5_file.keys()`` yields
        them.
    """
    # list() materialises the key view directly; no manual append loop needed.
    return list(h5_file.keys())
# details of 9-D vector: https://github.com/charlesq34/pointnet/issues/7 | |
def get_feature(point_cloud, label):
    """Build the TFRecord feature dict for one point cloud sample.

    Args:
        point_cloud: 2-D array indexed as ``[point, channel]``; columns 0-8
            are each stored as a separate float feature.
        label: Flat sequence of integer labels, stored under ``'label'``.

    Returns:
        dict mapping feature name to ``tf.train.Feature``.
    """
    # NOTE(review): the linked PointNet issue appears to describe the 9-D
    # layout as XYZ, RGB, normalized XYZ -- confirm that the names below match
    # the column meaning. The reader uses the same names in the same order,
    # so the round trip is unaffected either way.
    channel_names = ('x', 'y', 'z', 'norm_x', 'norm_y', 'norm_z', 'r', 'g', 'b')
    feature = {
        name: _floats_feature(point_cloud[:, column])
        for column, name in enumerate(channel_names)
    }
    feature['label'] = _int64_feature(label)
    return feature
def h5_to_tfrecord_converter(input_file_path, output_file_path):
    """Convert one HDF5 file into a TFRecord file.

    The first key of the HDF5 file is treated as the point cloud dataset and
    the second key as the label dataset (in the order the file reports its
    keys). One ``tf.train.Example`` is written per item along the first axis.

    Args:
        input_file_path: Path to the HDF5 file to read.
        output_file_path: Path of the TFRecord file to write.

    Raises:
        ValueError: If the file has fewer than two keys, or if the datasets
            do not all have the same number of items.
    """
    # Open read-only and close the file deterministically when done.
    with h5py.File(input_file_path, 'r') as h5_file:
        keys = get_all_keys_from_h5(h5_file)
        if len(keys) < 2:
            raise ValueError(
                'Expected at least two datasets (data and label) in %s, found %d.'
                % (input_file_path, len(keys)))

        # Read .shape from the dataset handle: slicing with [:] would load the
        # whole dataset into memory just to count its rows.
        num_of_items = h5_file[keys[0]].shape[0]
        # Check the number of values in each key
        for key in keys:
            if h5_file[key].shape[0] != num_of_items:
                raise ValueError('Invalid values. The inequality of the number of values in each key.')

        # Look the datasets up once instead of on every iteration.
        point_clouds = h5_file[keys[0]]
        labels = h5_file[keys[1]]

        with tf.python_io.TFRecordWriter(output_file_path) as writer:
            for index in range(num_of_items):
                example = tf.train.Example(
                    features=tf.train.Features(
                        feature=get_feature(point_clouds[index], labels[index])
                    ))
                writer.write(example.SerializeToString())
                # '\r' keeps the progress percentage on a single console line.
                print('\r{:.1%}'.format((index + 1) / num_of_items), end='')
# With commandline enabled | |
if __name__ == "__main__":
    # Command-line entry point: convert a single .h5 file, or (with -r) every
    # .h5 file directly inside a folder.
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-file-path', required=True, help='Path to the input HDF5 file.')
    parser.add_argument('--output-file-path', default='', help='Path to the output TFRecords.')
    parser.add_argument('-r', action='store_true', help='Recursively find *.h5 files under pointed folder. This will not dive deeper into sub-folders.')
    FLAGS = parser.parse_args()

    INPUT_PATH = os.path.abspath(FLAGS.input_file_path)
    OUTPUT_PATH = FLAGS.output_file_path
    RECURSIVE = FLAGS.r

    if INPUT_PATH.endswith('.h5'):
        # Single-file mode; the '.tfrecord' suffix is always appended.
        if OUTPUT_PATH == '':
            OUTPUT_PATH = INPUT_PATH[:-3]
        print('Start converting...\t')
        h5_to_tfrecord_converter(INPUT_PATH, os.path.abspath(OUTPUT_PATH) + '.tfrecord')
    elif RECURSIVE:
        # Folder mode: INPUT_PATH is a directory, outputs keep the input
        # file names with the extension swapped.
        if OUTPUT_PATH == '':
            OUTPUT_PATH = INPUT_PATH
        output_dir = os.path.abspath(OUTPUT_PATH)
        # The writer cannot create missing directories, so make sure it exists.
        os.makedirs(output_dir, exist_ok=True)

        files = []
        # Sorted so the job order is reproducible across runs and platforms.
        for _file in sorted(os.listdir(INPUT_PATH)):
            if _file.endswith('.h5'):
                files.append((
                    os.path.join(INPUT_PATH, _file[:-3]),
                    os.path.join(output_dir, _file[:-3]),
                    _file
                ))
        print(len(files), 'of HDF5 file detected.')
        # Count jobs from 1 so the last one reads "N/N" rather than "N-1/N".
        for idx, (_input, _output, _file_name) in enumerate(files, 1):
            print('\n\ton job %d/%d, %s' % (idx, len(files), _file_name), end='')
            h5_to_tfrecord_converter(_input + '.h5', _output + '.tfrecord')
    else:
        raise ValueError('Not a valid HDF5 file provided, you may want to add -r.')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
In actual use, the records are fed through an iterator, so the whole dataset does not need to be held in your machine's memory.
'''
def tfrecords_to_dataset(handle):
    '''
    Build train/test input pipelines over TFRecord files, sharing one
    feedable iterator.

    Note: tf.data.Dataset does not let us explicitly select which data to use,
    hence the file list is split manually (TEST_PERCENT of the files go to
    the test set, the rest to the training set).

    Input:
        handle: string placeholder used to pick the active iterator
    Returns:
        (next element of the feedable iterator,
         initializable iterator over the training set,
         initializable iterator over the test set)
    '''
    files = [ PATH_TO_EACH_TFRECORD ]
    TEST_PERCENT = 0.2
    split_at = int(TEST_PERCENT*len(files))

    def _pipeline(paths):
        # Parse every record, then batch, dropping the incomplete last batch.
        dataset = tf.data.TFRecordDataset(paths)
        dataset = dataset.map(feature_retrieval)
        return dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))

    train_dataset = _pipeline(files[split_at:])
    test_dataset = _pipeline(files[:split_at])

    # Both datasets share one structure, so the training set's types and
    # shapes describe the feedable iterator.
    feedable_iterator = tf.data.Iterator.from_string_handle(
        handle,
        train_dataset.output_types,
        train_dataset.output_shapes)
    next_elem = feedable_iterator.get_next()

    train_init_iter = train_dataset.make_initializable_iterator()
    test_init_iter = test_dataset.make_initializable_iterator()
    return next_elem, train_init_iter, test_init_iter
if __name__ == '__main__':
    # Usage sketch: one graph, with the string handle fed at run time
    # selecting whether the training or the test iterator supplies the data.
    handle_pl = tf.placeholder(tf.string, shape=[])
    next_value, train_init_iter, test_init_iter = tfrecords_to_dataset(handle_pl)
    pointclouds_pl, labels_pl = next_value
    loss = YOUR_MODEL(pointclouds_pl, labels_pl)

    # Training
    with tf.Session() as sess:
        training_handle = sess.run(train_init_iter.string_handle())
        sess.run(train_init_iter.initializer)
        # feed_dict is keyed by the placeholder tensor itself.
        sess.run(loss, feed_dict={handle_pl: training_handle})

    # Testing
    with tf.Session() as sess:
        testing_handle = sess.run(test_init_iter.string_handle())
        sess.run(test_init_iter.initializer)
        sess.run(loss, feed_dict={handle_pl: testing_handle})
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
If you are interested in what is inside your TFRecords,
you can load them and print them out as shown below.
'''
def get_tfrecords_features(num_points=4096):
    '''
    Describe the features stored in each serialized example.

    Input:
        num_points: number of values stored per feature; defaults to 4096
    Returns:
        dict mapping feature name to tf.FixedLenFeature: nine float32
        channels plus an int64 'label', each of length num_points
    '''
    float_keys = ('x', 'y', 'z', 'norm_x', 'norm_y', 'norm_z', 'r', 'g', 'b')
    features = {
        key: tf.FixedLenFeature([num_points], tf.float32)
        for key in float_keys
    }
    features['label'] = tf.FixedLenFeature([num_points], tf.int64)
    return features
def load_tfrecords(tfrecords_filepath):
    '''
    Parse every record of a tfrecord file and stack the results.

    Input:
        Path to a tfrecord file
    Returns:
        Tuple (data, labels) of Tensors, each stacking the per-record
        tensors along a new first axis
    '''
    items = []
    labels = []
    print("Loading %s" % tfrecords_filepath)
    # No session is opened here: this function only builds graph tensors,
    # nothing is evaluated.
    for serialized_example in tf.python_io.tf_record_iterator(tfrecords_filepath):
        data, label = feature_retrieval(serialized_example)
        items.append(data)
        labels.append(label)
    print("Finished Loading %s" % tfrecords_filepath)
    return (tf.stack(items), tf.stack(labels))
def feature_retrieval(serialized_example):
    '''
    Parse one serialized example back into (data, label) tensors.

    Input:
        serialized_example: a single serialized tf.train.Example
    Returns:
        data: the nine float channels stacked and transposed, so that the
              channels lie along the last axis
        label: the int64 'label' feature
    '''
    parsed = tf.parse_single_example(serialized_example, features=get_tfrecords_features())
    # Same channel order as used when the records were written.
    channel_keys = ('x', 'y', 'z', 'norm_x', 'norm_y', 'norm_z', 'r', 'g', 'b')
    channels = [tf.cast(parsed[key], tf.float32) for key in channel_keys]
    _label = tf.cast(parsed['label'], tf.int64)
    # Stacking yields one row per channel; transpose to one row per point.
    stacked = tf.stack(channels)
    data = tf.transpose(stacked)
    label = tf.transpose(_label)
    return data, label
if __name__ == '__main__':
    # The path was previously an undefined name; take it from the command line.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('tfrecords_filepath', help='Path to the TFRecord file to inspect.')
    tfrecords_filepath = parser.parse_args().tfrecords_filepath

    data, label = load_tfrecords(tfrecords_filepath)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment