Last active
October 22, 2020 18:10
-
-
Save Smthri/ef45c324d57f00eb8f3314aeb0601ce0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
template <typename scalar_t, typename idx_t>
void softmaxLUT_kernel(scalar_t* input, scalar_t* output, idx_t n, idx_t c, idx_t sizeC,
                       idx_t input_sN, idx_t input_sC, idx_t output_sN, idx_t output_sC,
                       idx_t multiplier, idx_t zero_point, idx_t* exp_table) {
    /*
     * Perform lookup-table-based quantized softmax for the single element (n, c).
     *
     *   A_i       = input[n, c]
     *   exp_sum   = sum over j in [0, sizeC) of exp_table[A_j - A_i + 255], A_j = input[n, j]
     *   output[n, c] = clamp(multiplier / exp_sum + zero_point)   (integer division)
     *
     * The result is clamped to [numeric_limits<scalar_t>::lowest(), max()].
     *
     * Parameters:
     *   input, output         base pointers of the input / output buffers.
     *   n, c                  row index and channel index of the element to compute.
     *   sizeC                 number of channels in row n (c is expected to be < sizeC).
     *   input_sN, input_sC    element strides of `input` along N and C.
     *   output_sN, output_sC  element strides of `output` along N and C.
     *   multiplier            numerator of the integer division.
     *   zero_point            offset added to the quotient before clamping.
     *   exp_table             lookup table indexed by (A_j - A_i + 255). The +255 offset
     *                         presumably assumes 8-bit inputs in [0, 255], i.e. a table of
     *                         511 entries — TODO confirm against the table builder.
     *
     * NOTE(review): the clamp converts numeric_limits<scalar_t> bounds to idx_t, so this
     * assumes an integral scalar_t. exp_sum must be nonzero; it always contains the
     * j == c term exp_table[255], so that entry is presumably required to be > 0.
     */
    const scalar_t* row = input + n * input_sN;
    const idx_t A_i = static_cast<idx_t>(row[c * input_sC]);

    idx_t exp_sum = 0;
    // sizeC counts channels, not elements: advance the channel index by one and scale
    // by the stride, so every channel is visited regardless of the memory layout.
    for (idx_t j = 0; j < sizeC; ++j) {
        const idx_t A_j = static_cast<idx_t>(row[j * input_sC]);
        exp_sum += exp_table[A_j - A_i + 255];
    }

    idx_t out = multiplier / exp_sum + zero_point;

    // lowest() is the most negative representable value; for integral types it equals
    // min(), but it states the intent of a lower clamp bound unambiguously.
    const idx_t out_max = static_cast<idx_t>(std::numeric_limits<scalar_t>::max());
    const idx_t out_min = static_cast<idx_t>(std::numeric_limits<scalar_t>::lowest());
    if (out < out_min) {
        out = out_min;
    } else if (out > out_max) {
        out = out_max;
    }
    output[n * output_sN + c * output_sC] = static_cast<scalar_t>(out);
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.