Skip to content

Instantly share code, notes, and snippets.

@luoyetx
Last active March 16, 2025 19:36
Show Gist options
  • Save luoyetx/2907da4e9378781f2c0992eee1f61ea2 to your computer and use it in GitHub Desktop.
Save luoyetx/2907da4e9378781f2c0992eee1f61ea2 to your computer and use it in GitHub Desktop.
RecordDb is an easy interface to save (image, label) data to `mx.record`
"""RecordDb is an esay interface to save (image, label) data to `mx.record`
"""
import random
import mxnet as mx
class RecDb(object):
    """Buffered writer that saves (image, label) pairs to an `mx.record` file.

    Packed records are collected in an in-memory buffer. The buffer is
    written to the record file whenever it holds ``max_buffer`` entries and
    once more when the database is closed. When ``shuffle`` is enabled each
    buffer is shuffled right before it is written, so the shuffling is local
    to one buffer, not global over the whole file.

    Instances can be used as context managers; the record is closed on exit.
    """

    def __init__(self, path, label_width, quality=100, max_buffer=1000, shuffle=True):
        """Constructor

        Parameters
        ----------
        path: path to record file
        label_width: label width, every label passed to `put` must have this length
        quality: JPEG quality for encoding
        max_buffer: maximum size of the internal buffer
        shuffle: shuffle the buffer before commit to record
        """
        # Mark as closed first so that __del__ stays harmless if opening the
        # record below raises and the object is left half-initialised.
        self.is_closed = True
        self.label_width = label_width
        self.quality = quality
        self.max_buffer = max_buffer
        self.shuffle = shuffle
        self.cur_idx = 0
        self.buffer = []
        self.record = mx.recordio.MXRecordIO(path, 'w')
        self.is_closed = False

    def __enter__(self):
        """Return the database itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the record when leaving a ``with`` block; never suppresses errors."""
        self.close()
        return False

    def __del__(self):
        # getattr guards against instances whose __init__ never ran to the
        # point of setting `is_closed`.
        if not getattr(self, 'is_closed', True):
            self.close()

    def put(self, image, label):
        """Put an (image, label) to internal buffer, if the buffer size reaches
        `max_buffer`, commit the buffer and clear it

        Parameters
        ----------
        image: image array, or an already encoded image as a byte string
        label: label array, its length must equal `label_width`

        Raises
        ------
        AssertionError: if the label length differs from `label_width`
        """
        # Raised explicitly (instead of `assert`) so the check survives
        # `python -O`, while keeping the exception type callers already see.
        if len(label) != self.label_width:
            raise AssertionError("label width is wrong")
        header = mx.recordio.IRHeader(0, label, self.cur_idx, 0)
        # `bytes` is `str` on Python 2; on Python 3 encoded images arrive as
        # `bytes`, which must not be sent through the JPEG encoder again.
        if isinstance(image, (bytes, str)):
            packed = mx.recordio.pack(header, image)
        else:
            packed = mx.recordio.pack_img(header, image, quality=self.quality, img_fmt='.jpg')
        # Advance the running index so every record carries a distinct id.
        self.cur_idx += 1
        self.buffer.append(packed)
        if len(self.buffer) >= self.max_buffer:
            self.commit()

    def commit(self):
        """Write the buffer to rec file and clear the buffer
        """
        if self.shuffle:
            random.shuffle(self.buffer)
        for packed in self.buffer:
            self.record.write(packed)
        # Clear in place so external references to the list stay valid.
        del self.buffer[:]

    def close(self):
        """Flush any buffered records and close the record file.

        Calling `close` more than once is safe; later calls do nothing.
        """
        if self.is_closed:
            return
        try:
            if self.buffer:
                self.commit()
        finally:
            # Release the file handle even if the final commit fails.
            self.record.close()
            self.is_closed = True
def test():
    """Test RecDb

    Writes 100 constant-valued grayscale images with 3-element labels to
    'tmp.rec' (roughly half as pre-encoded JPEG bytes, the rest as raw
    arrays), reads them back through `mx.io.ImageRecordIter`, and prints the
    mean squared difference between the written and the read-back image and
    label means. The temporary record file is removed afterwards.
    """
    import os
    import cv2
    import numpy as np

    rec_path = 'tmp.rec'
    try:
        db = RecDb(rec_path, 3)
        items = []
        for i in range(100):
            image = np.ones((50, 100)) * i
            label = np.ones(3) * i
            items.append((image.astype(np.uint8), label))

        image_mean = np.zeros((50, 100))
        label_mean = np.zeros(3)
        for image, label in items:
            image_mean += image
            label_mean += label
            # Randomly exercise both input paths of RecDb.put.
            if np.random.rand() < 0.5:
                encoded = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 80])[1]
                # tobytes() replaces the deprecated ndarray.tostring().
                db.put(encoded.tobytes(), label)
            else:
                db.put(image, label)
        image_mean /= len(items)
        label_mean /= len(items)
        db.close()

        rec_iter = mx.io.ImageRecordIter(path_imgrec=rec_path,
                                         label_width=3,
                                         data_shape=(1, 50, 100),
                                         batch_size=len(items),
                                         shuffle=False)
        image_mean_ = np.zeros((50, 100))
        label_mean_ = np.zeros(3)
        for batch in rec_iter:
            image, label = batch.data[0].asnumpy(), batch.label[0].asnumpy()
            # Drop the single channel axis before summing over the batch.
            image_mean_ += image.reshape(len(items), 50, 100).sum(axis=0)
            label_mean_ += label.sum(axis=0)
        image_mean_ /= len(items)
        label_mean_ /= len(items)

        image_error = np.square(image_mean - image_mean_).mean()
        label_error = np.square(label_mean - label_mean_).mean()
        # Single-argument print() behaves the same on Python 2 and Python 3.
        print('image error: %s' % image_error)
        print('label error: %s' % label_error)
        # Drop the iterator before deleting the file it reads from.
        del rec_iter
    finally:
        if os.path.exists(rec_path):
            os.remove(rec_path)
# Run the round-trip self-test only when executed as a script, not on import.
if "__main__" == __name__:
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment