Last active
March 16, 2025 19:36
-
-
Save luoyetx/2907da4e9378781f2c0992eee1f61ea2 to your computer and use it in GitHub Desktop.
RecordDb is an easy interface for saving (image, label) data to an `mx.recordio` record file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""RecordDb is an esay interface to save (image, label) data to `mx.record` | |
""" | |
import random | |
import mxnet as mx | |
class RecDb(object):
    """Buffered writer that saves (image, label) pairs to an `mx.recordio` file.

    Packed records are collected in an in-memory buffer and written out in
    batches of up to ``max_buffer`` items. When ``shuffle`` is enabled, each
    batch is shuffled before it is written. Shuffling happens only *within* a
    batch, so records never move across batch boundaries.

    The object can be used as a context manager. On exit the buffer is
    flushed and the record file is closed.
    """

    def __init__(self, path, label_width, quality=100, max_buffer=1000, shuffle=True):
        """Open ``path`` as a record file for writing.

        Parameters
        ----------
        path: path to the record file (opened with mode 'w')
        label_width: number of elements every label passed to `put` must have
        quality: JPEG quality used when `put` has to encode a raw image array
        max_buffer: number of records buffered before they are written out
        shuffle: shuffle each buffered batch before writing it
        """
        self.label_width = label_width
        self.quality = quality
        self.max_buffer = max_buffer
        self.shuffle = shuffle
        self.cur_idx = 0
        self.record = mx.recordio.MXRecordIO(path, 'w')
        self.buffer = []
        self.is_closed = False

    def __del__(self):
        # `getattr` guards against `__init__` having failed before
        # `is_closed` was assigned (e.g. the record file could not be opened).
        if not getattr(self, 'is_closed', True):
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()
        return False  # never swallow exceptions raised inside the `with`

    def put(self, image, label):
        """Add one (image, label) pair to the internal buffer.

        When the buffer reaches ``max_buffer`` items, it is committed to the
        record file and cleared.

        Parameters
        ----------
        image: already-encoded image bytes, stored as-is, or an image array,
            encoded as JPEG with ``self.quality``
        label: sequence of exactly ``label_width`` values

        Raises
        ------
        ValueError: if ``len(label) != label_width``
        """
        # An explicit raise, unlike `assert`, still validates under `python -O`.
        if len(label) != self.label_width:
            raise ValueError("label width is wrong: expected %d, got %d"
                             % (self.label_width, len(label)))
        header = mx.recordio.IRHeader(0, label, self.cur_idx, 0)
        # `bytes` is `str` on Python 2, and it is what
        # `cv2.imencode(...)[1].tobytes()` returns on Python 3, so
        # pre-encoded images are detected on both.
        if isinstance(image, bytes):
            s = mx.recordio.pack(header, image)
        else:
            s = mx.recordio.pack_img(header, image, quality=self.quality, img_fmt='.jpg')
        self.buffer.append(s)
        # Advance the id so every record gets a distinct one.
        self.cur_idx += 1
        if len(self.buffer) >= self.max_buffer:
            self.commit()

    def commit(self):
        """Write all buffered records to the record file, then empty the buffer.

        If ``shuffle`` is set, the buffer is shuffled in place first.
        """
        if self.shuffle:
            random.shuffle(self.buffer)
        for s in self.buffer:
            self.record.write(s)
        del self.buffer[:]

    def close(self):
        """Flush any buffered records and close the record file.

        Calling ``close`` more than once is a no-op.
        """
        if self.is_closed:
            return
        if self.buffer:
            self.commit()
        self.record.close()
        self.is_closed = True
def test():
    """Round-trip smoke test for RecDb.

    The test writes 100 constant-valued 50x100 single-channel uint8 images
    with 3-element labels. About half of the images are passed pre-encoded as
    JPEG bytes and the rest as raw arrays. It then reads everything back with
    `mx.io.ImageRecordIter` and prints the mean squared error between the
    per-pixel and per-label means before and after. JPEG encoding is lossy,
    so a small image error is expected.

    The record file lives in a temporary directory that is always removed.
    """
    import os
    import shutil
    import tempfile
    import cv2
    import numpy as np

    num_items = 100
    height, width = 50, 100
    label_width = 3

    items = []
    for i in range(num_items):
        image = (np.ones((height, width)) * i).astype(np.uint8)
        label = np.ones(label_width) * i
        items.append((image, label))

    tmp_dir = tempfile.mkdtemp()
    rec_path = os.path.join(tmp_dir, 'tmp.rec')
    try:
        db = RecDb(rec_path, label_width)
        image_mean = np.zeros((height, width))
        label_mean = np.zeros(label_width)
        for image, label in items:
            image_mean += image
            label_mean += label
            if np.random.rand() < 0.5:
                # Exercise the pre-encoded (bytes) path of RecDb.put.
                s = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 80])[1].tobytes()
                db.put(s, label)
            else:
                db.put(image, label)
        image_mean /= num_items
        label_mean /= num_items
        db.close()

        rec_iter = mx.io.ImageRecordIter(path_imgrec=rec_path,
                                         label_width=label_width,
                                         data_shape=(1, height, width),
                                         batch_size=num_items,
                                         shuffle=False)
        image_mean_ = np.zeros((height, width))
        label_mean_ = np.zeros(label_width)
        for batch in rec_iter:
            image, label = batch.data[0].asnumpy(), batch.label[0].asnumpy()
            image_mean_ += image.reshape(-1, height, width).sum(axis=0)
            label_mean_ += label.sum(axis=0)
        image_mean_ /= num_items
        label_mean_ /= num_items
        # Drop the iterator before deleting the file it reads from.
        del rec_iter
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)

    image_error = np.square(image_mean - image_mean_).mean()
    label_error = np.square(label_mean - label_mean_).mean()
    # Single-argument print() prints the same text on Python 2 and Python 3.
    print('image error: %s' % image_error)
    print('label error: %s' % label_error)
if __name__ == "__main__":
    # Running this file directly executes the round-trip smoke test.
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment