Skip to content

Instantly share code, notes, and snippets.

@Smthri
Last active October 22, 2020 18:10
Show Gist options
  • Save Smthri/ef45c324d57f00eb8f3314aeb0601ce0 to your computer and use it in GitHub Desktop.
template <typename scalar_t, typename idx_t>
void softmaxLUT_kernel(scalar_t* input, scalar_t* output, idx_t n, idx_t c, idx_t sizeC,
idx_t input_sN, idx_t input_sC, idx_t output_sN, idx_t output_sC,
idx_t multiplier, idx_t zero_point, idx_t* exp_table) {
/*
 * Quantized softmax for a single element (batch n, channel c):
 *
 *   output_i = multiplier / sum_j(exp_table[A_j - A_i + 255]) + zero_point
 *
 * Parameters:
 *   input/output       - quantized tensors; element (n, c) is addressed via
 *                        the per-batch (sN) and per-channel (sC) strides.
 *   sizeC              - number of channels to sum over.
 *   multiplier,
 *   zero_point         - requantization parameters applied to 1 / exp_sum.
 *   exp_table          - lookup table of exp values indexed by
 *                        (A_j - A_i + 255); with 8-bit quantized inputs the
 *                        difference lies in [-255, 255], so the table must
 *                        hold 511 entries.
 *
 * Fix vs. the original: the channel loop advanced the element *offset* by
 * input_sC while comparing against the channel *count* sizeC, so for
 * input_sC > 1 only sizeC / input_sC channels were summed. The loop now
 * iterates over channel indices and applies the stride explicitly; behavior
 * is unchanged for the contiguous case input_sC == 1.
 */
const idx_t A_i = static_cast<idx_t>(input[n * input_sN + c * input_sC]);
const scalar_t* row = input + n * input_sN;  // start of batch row n
idx_t exp_sum = 0;
for (idx_t j = 0; j < sizeC; ++j) {          // sum over ALL sizeC channels
idx_t A_j = static_cast<idx_t>(row[j * input_sC]);
exp_sum += exp_table[A_j - A_i + 255];
}
// Requantize and clamp to the representable range of scalar_t.
// NOTE(review): numeric_limits<scalar_t>::min() is the smallest *positive*
// value for floating-point types; this kernel assumes an integral
// (quantized) scalar_t, for which min() == lowest() — confirm at call sites.
idx_t out = multiplier / exp_sum + zero_point;
const idx_t out_max = std::numeric_limits<scalar_t>::max();
const idx_t out_min = std::numeric_limits<scalar_t>::min();
out = out < out_min ? out_min : (out > out_max ? out_max : out);
output[n * output_sN + c * output_sC] = static_cast<scalar_t>(out);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment