Created
February 10, 2021 21:21
-
-
Save nem0/2009e821026de1e5f102bb0dc15203bc to your computer and use it in GitHub Desktop.
radix sort
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Multi-worker LSD radix sort of `_keys`, moving `_values` along with their keys.
//
// Each of the 6 passes:
//   1. splits the input into WORKERS contiguous slices and builds one histogram
//      per slice (dispatched through jobs::forEach),
//   2. turns the histograms into destination offsets, bucket-major and
//      worker-minor, so equal digits keep their input order,
//   3. scatters every slice into the other buffer using its own offsets.
//
// The source and destination buffers are swapped after every pass; with an even
// number of passes the last scatter writes into the caller's `_keys`/`_values`.
//
// _keys   - keys to sort, `size` elements, sorted in place.
// _values - payload, `size` elements, permuted the same way as the keys.
// size    - number of elements; nothing happens for 0.
void radixSort2(u64* _keys, u64* _values, int size) {
	enum { WORKERS = 10 };
	// Digit width of the per-worker histograms.
	// NOTE(review): `shift` advances by Histogram::BITS below; that is presumably
	// also 11 - confirm against the Histogram definition.
	constexpr u32 RADIX_BITS = 11;
	constexpr u32 BUCKET_COUNT = 1 << RADIX_BITS;
	constexpr u64 BIT_MASK = BUCKET_COUNT - 1;

	PROFILE_FUNCTION();
	profiler::pushInt("count", size);
	if (size == 0) return;

	Array<u64>& tmp_mem = allocRadixTmp();
	// Scratch layout: [0, size) holds keys, [size, 2 * size) holds values.
	tmp_mem.resize(size * 2);

	u64* keys = _keys;
	u64* values = _values;
	u64* tmp_keys = tmp_mem.begin();
	u64* tmp_values = &tmp_mem[size];
	u16 shift = 0;

	// Slice length per worker (rounded up); it does not change between passes.
	const u32 step = (size + WORKERS - 1) / WORKERS;
	const u32 total = static_cast<u32>(size);

	for (int pass = 0; pass < 6; ++pass) {
		u32 histograms[WORKERS][BUCKET_COUNT];

		jobs::forEach(WORKERS, 1, [&](i32 idx, i32) {
			PROFILE_BLOCK("histo2");
			// Clamp the start so trailing workers get an empty slice (from == to)
			// instead of from > to when size < WORKERS * step.
			const u32 begin = idx * step;
			const u32 from = begin < total ? begin : total;
			const u32 to = minimum(from + step, size);

			u32* local_histogram = histograms[idx];
			memset(local_histogram, 0, sizeof(histograms[idx]));

			for (u32 i = from; i < to; ++i) {
				const u64 key = keys[i];
				const u64 index = (key >> shift) & BIT_MASK;
				++local_histogram[index];
			}
		});

		// Exclusive prefix sum over (bucket, worker): after this loop
		// histograms[w][b] is the first output slot for worker w's keys with digit b.
		u32 offset = 0;
		for (u32 bucket = 0; bucket < BUCKET_COUNT; ++bucket) {
			for (u32 worker = 0; worker < lengthOf(histograms); ++worker) {
				const u32 count = histograms[worker][bucket];
				histograms[worker][bucket] = offset;
				offset += count;
			}
		}

		jobs::forEach(WORKERS, 1, [&](i32 idx, i32) {
			PROFILE_BLOCK("scatter multi");
			const u32 begin = idx * step;
			const u32 from = begin < total ? begin : total;
			const u32 to = minimum(from + step, size);

			u32* local_histogram = histograms[idx];
			// `to >= from` always holds thanks to the clamp, so this never wraps.
			profiler::pushInt("count", to - from);

			for (u32 i = from; i < to; ++i) {
				const u64 key = keys[i];
				const u64 index = (key >> shift) & BIT_MASK;
				const u32 dest = local_histogram[index]++;
				ASSERT(dest < (u32)size);
				tmp_keys[dest] = key;
				tmp_values[dest] = values[i];
			}
		});

		// The freshly written buffer becomes the input of the next pass.
		swap(tmp_keys, keys);
		swap(tmp_values, values);
		shift += Histogram::BITS;
	}

	releaseRadixTmp(tmp_mem);
}
// LSD radix sort of `_keys` (with `_values` carried along) that scatters each pass
// from both ends at once:
//   - the calling thread walks the first half of the input front-to-back and fills
//     every bucket upwards from its start offset,
//   - a job started with jobs::run walks the second half back-to-front and fills
//     every bucket downwards from its end offset.
// The two writers use separate offset tables (histogram.m_histogram and
// `bucket_ends`), and jobs::wait is called before the buffers are swapped.
//
// _keys   - keys to sort, `size` elements.
// _values - payload, `size` elements, permuted together with the keys.
// size    - element count; returns immediately for 0.
void radixSort3(u64* _keys, u64* _values, int size) {
	PROFILE_FUNCTION();
	profiler::pushInt("count", size);
	if (size == 0) return;

	Array<u64>& scratch = allocRadixTmp();
	u64* src_keys = _keys;
	u64* src_values = _values;
	u64* dst_keys = nullptr;
	u64* dst_values = nullptr;
	Histogram histogram;
	u16 shift = 0;
	// Boundary between the part scattered here and the part scattered by the job.
	const int half = size / 2;

	for (int pass = 0; pass < 6; ++pass) {
		histogram.compute(src_keys, src_values, size, shift);

		// Disabled early-out, kept for reference:
		// if (histogram.m_sorted) {
		//	if (pass & 1) {
		//		memcpy(_keys, tmp_mem.begin(), tmp_mem.byte_size() / 2);
		//		memcpy(_values, &tmp_mem[size], tmp_mem.byte_size() / 2);
		//	}
		//	return;
		//}

		// Scratch memory is set up lazily: keys in [0, size), values in [size, 2 * size).
		if (dst_keys == nullptr) {
			scratch.resize(size * 2);
			dst_keys = scratch.begin();
			dst_values = &scratch[size];
		}

		// Counts -> exclusive prefix sums, i.e. the first output slot of each bucket.
		u32 running = 0;
		for (int bucket = 0; bucket < Histogram::SIZE; ++bucket) {
			const u32 bucket_size = histogram.m_histogram[bucket];
			histogram.m_histogram[bucket] = running;
			running += bucket_size;
		}

		// One-past-the-end slot of each bucket: the start of the next bucket,
		// and `size` for the last one.
		u32 bucket_ends[Histogram::SIZE];
		memcpy(bucket_ends, histogram.m_histogram + 1, sizeof(histogram.m_histogram) - sizeof(histogram.m_histogram[0]));
		bucket_ends[Histogram::SIZE - 1] = size;

		auto scatter_back_half = [&]() {
			PROFILE_BLOCK("back_pass");
			profiler::pushInt("pass", pass);
			profiler::pushInt("count", size - half);
			u64* LUMIX_RESTRICT in_keys = src_keys;
			u64* LUMIX_RESTRICT in_values = src_values;
			u64* LUMIX_RESTRICT out_keys = dst_keys;
			u64* LUMIX_RESTRICT out_values = dst_values;
			u32* LUMIX_RESTRICT ends = bucket_ends;
			// Reverse walk + pre-decrement keeps equal digits in input order.
			for (int i = size - 1; i >= half; --i) {
				const u64 key = in_keys[i];
				const u16 digit = (key >> shift) & Histogram::BIT_MASK;
				const u32 slot = --ends[digit];
				out_keys[slot] = key;
				out_values[slot] = in_values[i];
			}
		};

		// The lambda lives on this stack frame; jobs::wait below is called before
		// the frame of this iteration is left.
		jobs::SignalHandle back_done = jobs::INVALID_HANDLE;
		jobs::run(
			&scatter_back_half,
			[](void* data) {
				auto task = (decltype(scatter_back_half)*)data;
				(*task)();
			},
			&back_done);

		// Front half, scattered on the calling thread.
		for (int i = 0; i < half; ++i) {
			const u64 key = src_keys[i];
			const u16 digit = (key >> shift) & Histogram::BIT_MASK;
			const u32 slot = histogram.m_histogram[digit]++;
			dst_keys[slot] = key;
			dst_values[slot] = src_values[i];
		}

		jobs::wait(back_done);

		// Output of this pass is the input of the next one.
		swap(dst_keys, src_keys);
		swap(dst_values, src_values);
		shift += Histogram::BITS;
	}

	releaseRadixTmp(scratch);
}
// Single-threaded LSD radix sort of `_keys`, with `_values` permuted alongside.
// Runs 6 passes, advancing the digit position by Histogram::BITS each time and
// ping-ponging between the caller's arrays and a scratch buffer.
// The scatter is split into two loops (first half / second half of the input)
// so that the first half gets its own profiler block.
//
// _keys   - keys to sort, `size` elements.
// _values - payload, `size` elements, reordered together with the keys.
// size    - element count; returns immediately for 0.
void radixSort(u64* _keys, u64* _values, int size) {
	PROFILE_FUNCTION();
	profiler::pushInt("count", size);
	if (size == 0) return;

	Array<u64>& scratch = allocRadixTmp();
	u64* src_keys = _keys;
	u64* src_values = _values;
	u64* dst_keys = nullptr;
	u64* dst_values = nullptr;
	Histogram histogram;
	u16 shift = 0;
	const int half = size / 2;

	for (int pass = 0; pass < 6; ++pass) {
		histogram.compute(src_keys, src_values, size, shift);
		PROFILE_BLOCK("radix sort pass");
		profiler::pushInt("count", size);

		// Disabled early-out, kept for reference:
		// if (histogram.m_sorted) {
		//	if (pass & 1) {
		//		memcpy(_keys, tmp_mem.begin(), tmp_mem.byte_size() / 2);
		//		memcpy(_values, &tmp_mem[size], tmp_mem.byte_size() / 2);
		//	}
		//	return;
		//}

		// Scratch memory is set up lazily: keys in [0, size), values in [size, 2 * size).
		if (dst_keys == nullptr) {
			scratch.resize(size * 2);
			dst_keys = scratch.begin();
			dst_values = &scratch[size];
		}

		// Counts -> exclusive prefix sums, i.e. the first output slot of each bucket.
		u32 running = 0;
		for (int bucket = 0; bucket < Histogram::SIZE; ++bucket) {
			const u32 bucket_size = histogram.m_histogram[bucket];
			histogram.m_histogram[bucket] = running;
			running += bucket_size;
		}

		{
			PROFILE_BLOCK("1sthalf");
			profiler::pushInt("count", half);
			for (int i = 0; i < half; ++i) {
				const u64 key = src_keys[i];
				const u16 digit = (key >> shift) & Histogram::BIT_MASK;
				const u32 slot = histogram.m_histogram[digit]++;
				dst_keys[slot] = key;
				dst_values[slot] = src_values[i];
			}
		}

		// Remaining elements continue from the offsets left by the first half.
		for (int i = half; i < size; ++i) {
			const u64 key = src_keys[i];
			const u16 digit = (key >> shift) & Histogram::BIT_MASK;
			const u32 slot = histogram.m_histogram[digit]++;
			dst_keys[slot] = key;
			dst_values[slot] = src_values[i];
		}

		// Output of this pass is the input of the next one.
		swap(dst_keys, src_keys);
		swap(dst_values, src_values);
		shift += Histogram::BITS;
	}

	releaseRadixTmp(scratch);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment