Last active
December 15, 2025 16:16
-
-
Save lmmx/6fb7cca550be89f81f83c43efb1381cf to your computer and use it in GitHub Desktop.
Build a full ONNX Runtime wheel (all of it) for Ampere (sm_86) only, reducing the .so size from 415 MB ⇒ 82 MB (wheel: 300 MB ⇒ 48 MB). Needs 1 edited ops file (gather_block_quantized.cu).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
# Build an ONNX Runtime Python wheel with the CUDA and TensorRT execution
# providers, compiling CUDA kernels for a single architecture (sm_86 by default).
#
# Before building, the upstream gather_block_quantized.cu is replaced by the
# patched copy in $BUILD_ROOT/edited_ops/.
#
# Environment overrides (defaults reproduce the original hard-coded values):
#   BUILD_ROOT   directory containing edited_ops/ and the onnxruntime checkout
#   ORT_VERSION  onnxruntime git tag to clone
#   CUDA_ARCHS   value passed as CMAKE_CUDA_ARCHITECTURES
set -euo pipefail

BUILD_ROOT="${BUILD_ROOT:-$HOME/lab/ort/build}"
ORT_VERSION="${ORT_VERSION:-v1.23.0}"
CUDA_ARCHS="${CUDA_ARCHS:-86}"
PATCHED_OP="edited_ops/gather_block_quantized.cu"

cd "$BUILD_ROOT"

# Fail before wiping the previous checkout if the patched op is missing;
# otherwise the script would delete it and re-clone only to fail at `cp`.
if [[ ! -f "$PATCHED_OP" ]]; then
  echo "error: patched op not found: $BUILD_ROOT/$PATCHED_OP" >&2
  exit 1
fi

# Clean previous attempt
rm -rf onnxruntime
git clone --recursive --depth 1 --branch "$ORT_VERSION" https://github.com/microsoft/onnxruntime
cp "$PATCHED_OP" onnxruntime/onnxruntime/contrib_ops/cuda/quantization/gather_block_quantized.cu
cd onnxruntime

# -Wno-error=attributes stops GCC's -Wattributes diagnostics from being
# promoted to errors, both for host C++ and (via -Xcompiler) nvcc's host pass.
./build.sh \
  --config Release \
  --build_shared_lib \
  --parallel \
  --use_tensorrt \
  --tensorrt_home /usr \
  --use_cuda \
  --cuda_home /usr/lib/nvidia-cuda-toolkit \
  --cudnn_home /usr \
  --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES="$CUDA_ARCHS" \
  --cmake_extra_defines onnxruntime_USE_FLASH_ATTENTION=OFF \
  --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF \
  --cmake_extra_defines onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=OFF \
  --cmake_extra_defines CMAKE_CXX_FLAGS="-Wno-error=attributes" \
  --cmake_extra_defines CMAKE_CUDA_FLAGS="-Xcompiler=-Wno-error=attributes" \
  --skip_tests \
  --build_wheel

ls -lh build/Linux/Release/dist/*.whl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
# Install cuDNN 9 for CUDA 12 from NVIDIA's local-repository .deb for Debian 12.
# Stops at the first failing step so a failed download never reaches dpkg/apt.
set -euo pipefail

CUDNN_VERSION="9.17.0"
CUDNN_REPO="cudnn-local-repo-debian12-${CUDNN_VERSION}"
CUDNN_DEB="${CUDNN_REPO}_1.0-1_amd64.deb"

wget "https://developer.download.nvidia.com/compute/cudnn/${CUDNN_VERSION}/local_installers/${CUDNN_DEB}"
sudo dpkg -i "$CUDNN_DEB"
# Install the repo's signing key (placed under /var by the .deb above) into
# /usr/share/keyrings. The glob must stay unquoted so the shell expands it.
sudo cp "/var/${CUDNN_REPO}"/cudnn-*-keyring.gpg /usr/share/keyrings/
sudo apt-get update
sudo apt-get -y install cudnn9-cuda-12
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Copyright (c) Microsoft Corporation. All rights reserved. | |
| // Licensed under the MIT License. | |
| #include "core/providers/cuda/cuda_common.h" | |
| #include "core/providers/cuda/cu_inc/common.cuh" | |
| #include "gather_block_quantized.cuh" | |
| using namespace onnxruntime::cuda; | |
namespace onnxruntime {
namespace contrib {
namespace cuda {

// Reads the idx-th `bits`-wide code from a packed byte buffer.
//
// Packing (as implemented below): each byte holds 8 / bits codes, lower
// element indices in the lower bits -- for bits == 4, element 2k is the low
// nibble of byte k and element 2k+1 the high nibble. Assumes bits divides 8
// (e.g. 4 or 8) -- TODO confirm the host side enforces this.
//
// When `sign` is true the code is treated as two's complement and
// sign-extended to int64_t; otherwise it is returned zero-extended.
template <typename T1>
__device__ inline int64_t get_val(const T1* data, int64_t idx, int64_t bits, bool sign) {
  const uint32_t mask = (1U << bits) - 1;
  const int64_t elems_per_byte = 8 / bits;
  const int64_t byte_idx = idx / elems_per_byte;
  const int64_t bit_offset = (idx % elems_per_byte) * bits;
  const uint8_t byte = reinterpret_cast<const uint8_t*>(data)[byte_idx];
  int64_t val = (byte >> bit_offset) & mask;

  // Sign-extend by subtracting 2^bits when the code's top bit is set. Written
  // as a subtraction rather than `val |= -1LL << bits` because left-shifting a
  // negative value is undefined behaviour before C++20.
  if (sign && ((val >> (bits - 1)) & 1)) {
    val -= int64_t{1} << bits;
  }
  return val;
}

// Gathers slices of a block-quantized tensor along one axis and dequantizes them.
//
// Launch: 1-D grid, one thread per output element (see
// LaunchGatherBlockQuantizedKernel); no shared memory.
//
// Layouts implied by the index arithmetic below (flattened, row-major):
//   data   : [before, gather_axis_dim, after_gather_dim]  (packed codes)
//   indices: [ind_dim]
//   output : [before, ind_dim, after_gather_dim], N elements in total
// scales / zero_points hold one entry per `block_size` consecutive elements of
// the flattened input (block_id = in_idx / block_size).
//
// output = (code - zero_point) * scale, with zero_point = 0 when zero_points is
// null. `gather_axis` is currently not read by the kernel.
template <typename T1, typename T2, typename Tind>
__global__ void GatherBlockQuantizedKernel(
    const T1* data,         // packed `bits`-wide quantized codes
    const Tind* indices,    // positions along the gather axis; negatives count from the end
    const T2* scales,       // one scale per quantization block
    const T1* zero_points,  // optional packed `bits`-wide zero points, one per block
    T2* output,
    int64_t after_gather_dim,
    int64_t gather_axis_dim,
    int64_t ind_dim,
    int64_t bits,
    int64_t block_size,
    int64_t gather_axis,
    int64_t N,
    bool sign) {
  // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
  // evaluated in 32-bit unsigned arithmetic and wrap once N reaches 2^32.
  const int64_t out_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (out_idx >= N) return;

  // Decompose the output position into (before, gathered index, after).
  const int64_t idx_before = out_idx / (after_gather_dim * ind_dim);
  const int64_t idx_after = out_idx % after_gather_dim;
  const int64_t idx = (out_idx % (after_gather_dim * ind_dim)) / after_gather_dim;

  int64_t idx_at_g = static_cast<int64_t>(indices[idx]);
  // Gather-style negative indices address from the end of the axis; used raw
  // they would read outside the current slice.
  // NOTE(review): indices outside [-gather_axis_dim, gather_axis_dim) are not
  // checked here -- presumably validated on the host; confirm.
  if (idx_at_g < 0) idx_at_g += gather_axis_dim;

  const int64_t in_idx = (idx_before * gather_axis_dim + idx_at_g) * after_gather_dim + idx_after;
  const int64_t block_id = in_idx / block_size;

  // Zero point for this block (0 when zero_points is absent).
  const int64_t offset = zero_points ? get_val(zero_points, block_id, bits, sign) : 0;

  // Raw quantized code for this element.
  const int64_t weight = get_val(data, in_idx, bits, sign);

  // Dequantize. The difference is routed through float, presumably because
  // half / BFloat16 lack an unambiguous conversion from int64_t.
  output[out_idx] = static_cast<T2>(static_cast<float>(weight - offset)) * scales[block_id];
}

// Host-side launcher: one thread per output element, enqueued on param.stream.
// Requires the quantized axis to be the last dimension (see kernel docs).
// NOTE(review): launch errors are not checked here; presumably the caller
// inspects cudaGetLastError() -- confirm.
template <typename T1, typename T2, typename Tind>
void LaunchGatherBlockQuantizedKernel(const T1* data,
                                      const Tind* indices,
                                      const T2* scales,
                                      const T1* zero_points,
                                      T2* output,
                                      GatherBlockQuantizedParam param) {
  // Empty output: a zero-block grid is an invalid launch configuration.
  if (param.N <= 0) return;

  // Integer ceil-division. A float-based ceil(N / threads) keeps only 24 bits
  // of N and can under-count blocks for N > 2^24, leaving the output tail unwritten.
  const int64_t threads_per_block = GridDim::maxThreadsPerBlock;
  const int blocksPerGrid = static_cast<int>((param.N + threads_per_block - 1) / threads_per_block);
  const bool sign = std::is_same<T1, Int4x2>::value;
  GatherBlockQuantizedKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, param.stream>>>(
      data, indices, scales, zero_points, output,
      param.after_gather_dim, param.gather_axis_dim, param.ind_dim, param.bits,
      param.block_size, param.gather_axis, param.N, sign);
}

template void LaunchGatherBlockQuantizedKernel<uint8_t, float, int32_t>(const uint8_t*, const int32_t*, const float*, const uint8_t*, float*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<uint8_t, float, int64_t>(const uint8_t*, const int64_t*, const float*, const uint8_t*, float*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<UInt4x2, float, int32_t>(const UInt4x2*, const int32_t*, const float*, const UInt4x2*, float*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<UInt4x2, float, int64_t>(const UInt4x2*, const int64_t*, const float*, const UInt4x2*, float*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<Int4x2, float, int32_t>(const Int4x2*, const int32_t*, const float*, const Int4x2*, float*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<Int4x2, float, int64_t>(const Int4x2*, const int64_t*, const float*, const Int4x2*, float*, GatherBlockQuantizedParam);

template void LaunchGatherBlockQuantizedKernel<uint8_t, half, int32_t>(const uint8_t*, const int32_t*, const half*, const uint8_t*, half*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<uint8_t, half, int64_t>(const uint8_t*, const int64_t*, const half*, const uint8_t*, half*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<UInt4x2, half, int32_t>(const UInt4x2*, const int32_t*, const half*, const UInt4x2*, half*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<UInt4x2, half, int64_t>(const UInt4x2*, const int64_t*, const half*, const UInt4x2*, half*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<Int4x2, half, int32_t>(const Int4x2*, const int32_t*, const half*, const Int4x2*, half*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<Int4x2, half, int64_t>(const Int4x2*, const int64_t*, const half*, const Int4x2*, half*, GatherBlockQuantizedParam);

template void LaunchGatherBlockQuantizedKernel<uint8_t, BFloat16, int32_t>(const uint8_t*, const int32_t*, const BFloat16*, const uint8_t*, BFloat16*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<uint8_t, BFloat16, int64_t>(const uint8_t*, const int64_t*, const BFloat16*, const uint8_t*, BFloat16*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<UInt4x2, BFloat16, int32_t>(const UInt4x2*, const int32_t*, const BFloat16*, const UInt4x2*, BFloat16*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<UInt4x2, BFloat16, int64_t>(const UInt4x2*, const int64_t*, const BFloat16*, const UInt4x2*, BFloat16*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<Int4x2, BFloat16, int32_t>(const Int4x2*, const int32_t*, const BFloat16*, const Int4x2*, BFloat16*, GatherBlockQuantizedParam);
template void LaunchGatherBlockQuantizedKernel<Int4x2, BFloat16, int64_t>(const Int4x2*, const int64_t*, const BFloat16*, const Int4x2*, BFloat16*, GatherBlockQuantizedParam);

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment