Revisions
-
ngxson revised this gist
Jan 28, 2025 . 1 changed file with 5 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,6 +1,10 @@ ### Why did you do this? Relax, I only have one Sunday to work on idea, literally my weekend project. So I tried Deepseek to see if it can help. Surprisingly, it works and it saves me another weekend... ### What is your setup? Just chat.deepseek.com (cost = free) with prompts adapted from this gist. ### Does it work in one-shot or I have to prompt it multiple times? -
ngxson revised this gist
Jan 28, 2025 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,4 +1,4 @@ Here is a function in C, you have to convert it to WASM SIMD 128. This function is prone to produce inaccuracy result, it is very sensitive. Only optimize part that you are absolutely sure. -
ngxson revised this gist
Jan 27, 2025 . 2 changed files with 653 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,63 @@ Here is a functionin C, you have to convert it to WASM SIMD 128. This function is prone to produce inaccuracy result, it is very sensitive. Only optimize part that you are absolutely sure. Think carefully about it. Hint, focus more on the loop with aux16, aux32 ``` void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); const block_q6_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; int8_t aux8[QK_K]; int16_t aux16[8]; float sums [8]; int32_t aux32[8]; memset(sums, 0, 8*sizeof(float)); float sumf = 0; for (int i = 0; i < nb; ++i) { const uint8_t * restrict q4 = x[i].ql; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; memset(aux32, 0, 8*sizeof(int32_t)); int8_t * restrict a = aux8; for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; } a += 128; q4 += 64; qh += 32; } a = aux8; int is = 0; for (int j = 0; j < QK_K/16; ++j) { int scale = x[i].scales[is++]; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; } ``` This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,590 @@ ## NOTE: This task (q6_K_q8_K) is HARD. The prompt below always produces failed result Your task is to convert a given C code SIMD to WASM SIMD. Here is an example of another function: ```c void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); const block_q4_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; static const uint32_t kmask1 = 0x3f3f3f3f; static const uint32_t kmask2 = 0x0f0f0f0f; static const uint32_t kmask3 = 0x03030303; uint32_t utmp[4]; #ifdef __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); ggml_int8x16x2_t q4bytes; ggml_int8x16x2_t q8bytes; float sumf = 0; for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); memcpy(utmp, x[i].scales, 12); uint32x2_t mins8 = { 0 }; mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[0] &= kmask1; const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); sumf -= dmin * vaddvq_s32(prod); const uint8_t * scales = (const uint8_t *)utmp; const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; int32_t sumi1 = 0; int32_t sumi2 = 0; for (int j = 0; j < QK_K/64; ++j) { const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); sumi1 += vaddvq_s32(p1) * scales[2*j+0]; q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); sumi2 += vaddvq_s32(p2) * scales[2*j+1]; } sumf += d * (sumi1 + sumi2); } *s = sumf; #elif defined(__wasm_simd128__) // WASM SIMD128 implementation const uint8_t * scales = (const uint8_t*)&utmp[0]; float sumf = 0; for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; // Process scales and mins memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; // Sum mins * q8sums int32_t sumi = 0; const int16_t * restrict q8sums = y[i].bsums; const uint8_t * m = (const uint8_t *)&utmp[2]; for (int j = 0; j < 16; j += 2) { sumi += (q8sums[j] + q8sums[j+1]) * m[j/2]; } sumf -= dmin * sumi; int32_t sumi1 = 0; int32_t sumi2 = 0; for (int j = 0; j < QK_K/64; ++j) { // Load 64 4-bit weights (32 bytes) const v128_t q4x0 = wasm_v128_load(q4); const v128_t q4x1 = wasm_v128_load(q4 + 16); q4 += 32; // Split into low/high nibbles const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F)); const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4); const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F)); const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4); // Load 64 8-bit values (64 bytes) const v128_t q8x0 = wasm_v128_load(q8); const v128_t q8x1 = wasm_v128_load(q8 + 16); const v128_t q8x2 = wasm_v128_load(q8 + 32); const v128_t q8x3 = wasm_v128_load(q8 + 48); q8 += 64; // Low nibble products v128_t vacc1 = wasm_i32x4_dot_i16x8( wasm_i16x8_extend_low_i8x16(q4l0), wasm_i16x8_extend_low_i8x16(q8x0) ); vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_high_i8x16(q4l0), wasm_i16x8_extend_high_i8x16(q8x0) )); vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_low_i8x16(q4l1), wasm_i16x8_extend_low_i8x16(q8x1) )); vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_high_i8x16(q4l1), wasm_i16x8_extend_high_i8x16(q8x1) )); // High nibble products v128_t vacc2 = wasm_i32x4_dot_i16x8( wasm_i16x8_extend_low_i8x16(q4h0), wasm_i16x8_extend_low_i8x16(q8x2) ); vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_high_i8x16(q4h0), wasm_i16x8_extend_high_i8x16(q8x2) )); vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_low_i8x16(q4h1), wasm_i16x8_extend_low_i8x16(q8x3) )); vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_high_i8x16(q4h1), wasm_i16x8_extend_high_i8x16(q8x3) )); // Accumulate scaled results int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) + wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3); sumi1 += vacc1_sum * scales[2*j]; int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) + wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3); sumi2 += vacc2_sum * scales[2*j+1]; } sumf += d * (sumi1 + sumi2); } *s = sumf; #elif defined __AVX__ const __m128i m4 = _mm_set1_epi8(0xF); const __m128i m2 = _mm_set1_epi8(0x2); __m256 acc = _mm256_setzero_ps(); __m128 acc_m = _mm_setzero_ps(); for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); const __m128i scales = _mm_cvtepu8_epi16(utmps); const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); const __m128i prod = _mm_madd_epi16(mins, q8s); acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); __m128i sumi_0 = _mm_setzero_si128(); __m128i sumi_1 = _mm_setzero_si128(); __m128i shuffle = _mm_set1_epi16(0x0100); for (int j = 0; j < QK_K/64; ++j) { const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); shuffle = _mm_add_epi16(shuffle, m2); const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); shuffle = _mm_add_epi16(shuffle, m2); __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4l_0 = _mm_and_si128(q4bits, m4); const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4l_1 = _mm_and_si128(q4bits, m4); const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); p16l = _mm_madd_epi16(scale_l, p16l); sumi_0 = _mm_add_epi32(sumi_0, p16l); const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; p16l = _mm_maddubs_epi16(q4l_1, q8l_1); p16l = _mm_madd_epi16(scale_l, p16l); sumi_1 = _mm_add_epi32(sumi_1, p16l); const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); p16h = _mm_madd_epi16(scale_h, p16h); sumi_0 = _mm_add_epi32(sumi_0, p16h); const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; p16h = _mm_maddubs_epi16(q4h_1, q8h_1); p16h = _mm_madd_epi16(scale_h, p16h); sumi_1 = _mm_add_epi32(sumi_1, p16h); } __m256 vd = _mm256_set1_ps(d); __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); } acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); #else const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; int8_t aux8[QK_K]; int16_t aux16[8]; float sums [8]; int32_t aux32[8]; memset(sums, 0, 8*sizeof(float)); float sumf = 0; for (int i = 0; i < nb; ++i) { const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; memset(aux32, 0, 8*sizeof(int32_t)); int8_t * restrict a = aux8; for (int j = 0; j < QK_K/64; ++j) { for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); a += 32; for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); a += 32; q4 += 32; } memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; int sumi = 0; for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; a = aux8; int is = 0; for (int j = 0; j < QK_K/32; ++j) { int32_t scale = scales[is++]; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; sumf -= dmin * sumi; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; #endif } ```` Here is a function. You need to convert it to WASM SIMD. ```c void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); const block_q6_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; #ifdef __ARM_NEON float sum = 0; const uint8x16_t m4b = vdupq_n_u8(0xF); const int32x4_t vzero = vdupq_n_s32(0); //const int8x16_t m32s = vdupq_n_s8(32); const uint8x16_t mone = vdupq_n_u8(3); ggml_int8x16x4_t q6bytes; ggml_uint8x16x4_t q6h; for (int i = 0; i < nb; ++i) { const float d_all = GGML_FP16_TO_FP32(x[i].d); const uint8_t * restrict q6 = x[i].ql; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; const int8_t * restrict scale = x[i].scales; const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); const int8x16_t scales = vld1q_s8(scale); const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); int32_t isum_mins = vaddvq_s32(prod); int32_t isum = 0; for (int j = 0; j < QK_K/128; ++j) { ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); shifted = vshrq_n_u8(qhbits.val[1], 2); q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; scale += 4; q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; shifted = vshrq_n_u8(qhbits.val[0], 4); q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); shifted = vshrq_n_u8(qhbits.val[1], 4); q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); shifted = vshrq_n_u8(qhbits.val[0], 6); q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); shifted = vshrq_n_u8(qhbits.val[1], 6); q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; scale += 4; } //sum += isum * d_all * y[i].d; sum += d_all * y[i].d * (isum - 32 * isum_mins); } *s = sum; #elif defined __AVX__ const __m128i m3 = _mm_set1_epi8(3); const __m128i m15 = _mm_set1_epi8(15); __m256 acc = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const uint8_t * restrict q4 = x[i].ql; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; // handle the q6_k -32 offset separately using bsums const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums); const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1); const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales); const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8)); const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5); const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5); __m128i sumi_0 = _mm_setzero_si128(); __m128i sumi_1 = _mm_setzero_si128(); int is = 0; for (int j = 0; j < QK_K/128; ++j) { const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2); const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2); const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48)); const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48)); const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2); const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2); const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0); const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1); const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2); const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3); const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4); const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5); const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6); const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7); const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); is += 4; p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1); p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3); p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5); p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7); sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); } sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0); sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1); const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc); } *s = hsum_float_8(acc); #else int8_t aux8[QK_K]; int16_t aux16[8]; float sums [8]; int32_t aux32[8]; memset(sums, 0, 8*sizeof(float)); float sumf = 0; for (int i = 0; i < nb; ++i) { const uint8_t * restrict q4 = x[i].ql; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; memset(aux32, 0, 8*sizeof(int32_t)); int8_t * restrict a = aux8; for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; } a += 128; q4 += 64; qh += 32; } a = aux8; int is = 0; for (int j = 0; j < QK_K/16; ++j) { int scale = x[i].scales[is++]; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; #endif } ``` You must start your code with `#elif defined(__wasm_simd128__)` To think about it, you need to take into account both the refenrence code from ARM NEON and AVX implementation. -
ngxson revised this gist
Jan 27, 2025 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -4,7 +4,7 @@ Just chat.deepseek.com with prompts adapted from this gist. ### Does it work in one-shot or I have to prompt it multiple times? - For the `qX_0` variants, they are actually quite straight-forward so deepseek can come up with a correct result in 1 shot. - For the `qX_K` it's more complicated, I would say most of the time I need to re-prompt it 4 to 8 more times. - The most difficult was `q6_K`, the code never works until I ask it to only optimize one specific part, while leaving the rest intact (so it does not mess up everything) -
ngxson revised this gist
Jan 27, 2025 . 1 changed file with 4 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,3 +1,7 @@ ### What is your setup? Just chat.deepseek.com with prompts adapted from this gist. ### Does it work in one-shot or I have to prompt it multiple times? - For the `qX_0` variants, they are actually quite straight-forward so deepseek can come up with a correct result in 1 shot. It is already crazy enough, given that ChatGPT and Claude have never produced a working result for me. -
ngxson revised this gist
Jan 27, 2025 . 1 changed file with 4 additions and 4 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -4,7 +4,7 @@ - For the `qX_K` it's more complicated, I would say most of the time I need to re-prompt it 4 to 8 more times. - The most difficult was `q6_K`, the code never works until I ask it to only optimize one specific part, while leaving the rest intact (so it does not mess up everything) ### It only does conversion ARM NEON --> WASM SIMD, or it can invent new WASM SIMD code from scratch? It can do both. For `qX_0` I asked it to convert, and for `qX_K` I asked it to invent new code. @@ -14,6 +14,6 @@ Around 3-5 minutes per response. ### Prompt is very long, what to do? You can condense the prompt to this format, then create a new conversation (see example at the bottom): 1. Problem description 2. "Here is your last failed attempt, improve from this: [paste the last generated code here]" -
ngxson revised this gist
Jan 27, 2025 . 1 changed file with 4 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -4,6 +4,10 @@ - For the `qX_K` it's more complicated, I would say most of the time I need to re-prompt it 4 to 8 more times. - The most difficult was `q6_K`, the code never works until I ask it to only optimize one specific part, while leaving the rest intact (so it does not mess up everything) ### It only does conversion ARM NEON --> WASM SIMD, or it can invent new WASM SIMD code from scratch?** It can do both. For `qX_0` I asked it to convert, and for `qX_K` I asked it to invent new code. ### How much time does it spent to think? Around 3-5 minutes per response. -
ngxson revised this gist
Jan 27, 2025 . 1 changed file with 3 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,14 +1,14 @@ ### Does it work in one-shot or I have to prompt it multiple times? - For the `qX_0` variants, they are actually quite straight-forward so deepseek can come up with a correct result in 1 shot. It is already crazy enough, given that ChatGPT and Claude have never produced a working result for me. - For the `qX_K` it's more complicated, I would say most of the time I need to re-prompt it 4 to 8 more times. - The most difficult was `q6_K`, the code never works until I ask it to only optimize one specific part, while leaving the rest intact (so it does not mess up everything) ### How much time does it spent to think? Around 3-5 minutes per response. ### Prompt is very long, what to do? You can condense the prompt to this format, then create a new conversation: - Problem description -
ngxson revised this gist
Jan 27, 2025 . 2 changed files with 16 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,15 @@ **Does it work in one-shot or I have to prompt it multiple times?** - For the `qX_0` variants, they are actually quite straight-forward so deepseek can come up with a correct result in 1 shot. It is already crazy enough, given that ChatGPT and Claude have never produced a working result for me. - For the `qX_K` it's more complicated, I would say most of the time I need to re-prompt it 4 to 8 more times. - The most difficult was `q6_K`, the code never works until I ask it to only optimize one specific part, while leaving the rest intact (so it does not mess up everything) **How much time does it spent to think?** Around 3-5 minutes per response. **Prompt is very long, what to do?** You can condense the prompt to this format, then create a new conversation: - Problem description - "Here is your last failed attempt, improve from this: [paste the last generated code here]" This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,4 +1,4 @@ Here is a function in C, you have to convert it to WASM SIMD 128: ```c -
ngxson revised this gist
Jan 26, 2025 . 1 changed file with 21 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,21 @@ Here is a functionin C, you have to convert it to WASM SIMD 128: ```c ``` You must start your code with `#elif defined(__wasm_simd128__)` To think about it, you need to take into account both the refenrence code from ARM NEON and AVX implementation. Please note that this is YOUR last attempt. It compiles, but give inaccurate result. Improve from that. ```c ``` You must start your code with `#elif defined(__wasm_simd128__)` To think about it, you need to take into account both the refenrence code from ARM NEON, riscv_v_intrinsic and other implementation that you can see on the code. Make sure the output value is as accurate as possible. -
ngxson revised this gist
Jan 26, 2025 . 1 changed file with 588 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,588 @@ Your task is to convert a given C code SIMD to WASM SIMD. Here is an example of another function: ```c void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); const block_q4_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; static const uint32_t kmask1 = 0x3f3f3f3f; static const uint32_t kmask2 = 0x0f0f0f0f; static const uint32_t kmask3 = 0x03030303; uint32_t utmp[4]; #ifdef __ARM_NEON const uint8x16_t m4b = vdupq_n_u8(0xf); const int32x4_t mzero = vdupq_n_s32(0); ggml_int8x16x2_t q4bytes; ggml_int8x16x2_t q8bytes; float sumf = 0; for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); memcpy(utmp, x[i].scales, 12); uint32x2_t mins8 = { 0 }; mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[0] &= kmask1; const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); sumf -= dmin * vaddvq_s32(prod); const uint8_t * scales = (const uint8_t *)utmp; const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; int32_t sumi1 = 0; int32_t sumi2 = 0; for (int j = 0; j < QK_K/64; ++j) { const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); sumi1 += vaddvq_s32(p1) * scales[2*j+0]; q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); sumi2 += vaddvq_s32(p2) * scales[2*j+1]; } sumf += d * (sumi1 + sumi2); } *s = sumf; #elif defined(__wasm_simd128__) // WASM SIMD128 implementation const uint8_t * scales = (const uint8_t*)&utmp[0]; float sumf = 0; for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; // Process scales and mins memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; // Sum mins * q8sums int32_t sumi = 0; const int16_t * restrict q8sums = y[i].bsums; const uint8_t * m = (const uint8_t *)&utmp[2]; for (int j = 0; j < 16; j += 2) { sumi += (q8sums[j] + q8sums[j+1]) * m[j/2]; } sumf -= dmin * sumi; int32_t sumi1 = 0; int32_t sumi2 = 0; for (int j = 0; j < QK_K/64; ++j) { // Load 64 4-bit weights (32 bytes) const v128_t q4x0 = wasm_v128_load(q4); const v128_t q4x1 = wasm_v128_load(q4 + 16); q4 += 32; // Split into low/high nibbles const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F)); const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4); const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F)); const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4); // Load 64 8-bit values (64 bytes) const v128_t q8x0 = wasm_v128_load(q8); const v128_t q8x1 = wasm_v128_load(q8 + 16); const v128_t q8x2 = wasm_v128_load(q8 + 32); const v128_t q8x3 = wasm_v128_load(q8 + 48); q8 += 64; // Low nibble products v128_t vacc1 = wasm_i32x4_dot_i16x8( wasm_i16x8_extend_low_i8x16(q4l0), wasm_i16x8_extend_low_i8x16(q8x0) ); vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_high_i8x16(q4l0), wasm_i16x8_extend_high_i8x16(q8x0) )); vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_low_i8x16(q4l1), wasm_i16x8_extend_low_i8x16(q8x1) )); vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_high_i8x16(q4l1), wasm_i16x8_extend_high_i8x16(q8x1) )); // High nibble products v128_t vacc2 = wasm_i32x4_dot_i16x8( wasm_i16x8_extend_low_i8x16(q4h0), wasm_i16x8_extend_low_i8x16(q8x2) ); vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_high_i8x16(q4h0), wasm_i16x8_extend_high_i8x16(q8x2) )); vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_low_i8x16(q4h1), wasm_i16x8_extend_low_i8x16(q8x3) )); vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( wasm_i16x8_extend_high_i8x16(q4h1), wasm_i16x8_extend_high_i8x16(q8x3) )); // Accumulate scaled results int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) + wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3); sumi1 += vacc1_sum * scales[2*j]; int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) + wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3); sumi2 += vacc2_sum * scales[2*j+1]; } sumf += d * (sumi1 + sumi2); } *s = sumf; #elif defined __AVX__ const __m128i m4 = _mm_set1_epi8(0xF); const __m128i m2 = _mm_set1_epi8(0x2); __m256 acc = _mm256_setzero_ps(); __m128 acc_m = _mm_setzero_ps(); for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); const __m128i scales = _mm_cvtepu8_epi16(utmps); const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); const __m128i prod = _mm_madd_epi16(mins, q8s); acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); __m128i sumi_0 = _mm_setzero_si128(); __m128i sumi_1 = _mm_setzero_si128(); __m128i shuffle = _mm_set1_epi16(0x0100); for (int j = 0; j < QK_K/64; ++j) { const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); shuffle = _mm_add_epi16(shuffle, m2); const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); shuffle = _mm_add_epi16(shuffle, m2); __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4l_0 = _mm_and_si128(q4bits, m4); const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4l_1 = _mm_and_si128(q4bits, m4); const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); p16l = _mm_madd_epi16(scale_l, p16l); sumi_0 = _mm_add_epi32(sumi_0, p16l); const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; p16l = _mm_maddubs_epi16(q4l_1, q8l_1); p16l = _mm_madd_epi16(scale_l, p16l); sumi_1 = _mm_add_epi32(sumi_1, p16l); const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); p16h = _mm_madd_epi16(scale_h, p16h); sumi_0 = _mm_add_epi32(sumi_0, p16h); const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; p16h = _mm_maddubs_epi16(q4h_1, q8h_1); p16h = _mm_madd_epi16(scale_h, p16h); sumi_1 = _mm_add_epi32(sumi_1, p16h); } __m256 vd = _mm256_set1_ps(d); __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); } acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); #else const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; int8_t aux8[QK_K]; int16_t aux16[8]; float sums [8]; int32_t aux32[8]; memset(sums, 0, 8*sizeof(float)); float sumf = 0; for (int i = 0; i < nb; ++i) { const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; memset(aux32, 0, 8*sizeof(int32_t)); int8_t * restrict a = aux8; for (int j = 0; j < QK_K/64; ++j) { for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); a += 32; for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); a += 32; q4 += 32; } memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; int sumi = 0; for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; a = aux8; int is = 0; for (int j = 0; j < QK_K/32; ++j) { int32_t scale = scales[is++]; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; sumf -= dmin * sumi; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; #endif } ```` Here is a function. You need to convert it to WASM SIMD. ```c void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); const block_q6_K * restrict x = vx; const block_q8_K * restrict y = vy; const int nb = n / QK_K; #ifdef __ARM_NEON float sum = 0; const uint8x16_t m4b = vdupq_n_u8(0xF); const int32x4_t vzero = vdupq_n_s32(0); //const int8x16_t m32s = vdupq_n_s8(32); const uint8x16_t mone = vdupq_n_u8(3); ggml_int8x16x4_t q6bytes; ggml_uint8x16x4_t q6h; for (int i = 0; i < nb; ++i) { const float d_all = GGML_FP16_TO_FP32(x[i].d); const uint8_t * restrict q6 = x[i].ql; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; const int8_t * restrict scale = x[i].scales; const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); const int8x16_t scales = vld1q_s8(scale); const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); int32_t isum_mins = vaddvq_s32(prod); int32_t isum = 0; for (int j = 0; j < QK_K/128; ++j) { ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); shifted = vshrq_n_u8(qhbits.val[1], 2); q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; scale += 4; q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; shifted = vshrq_n_u8(qhbits.val[0], 4); q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); shifted = vshrq_n_u8(qhbits.val[1], 4); q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); shifted = vshrq_n_u8(qhbits.val[0], 6); q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); shifted = vshrq_n_u8(qhbits.val[1], 6); q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; scale += 4; } //sum += isum * d_all * y[i].d; sum += d_all * y[i].d * (isum - 32 * isum_mins); } *s = sum; #elif defined __AVX__ const __m128i m3 = _mm_set1_epi8(3); const __m128i m15 = _mm_set1_epi8(15); __m256 acc = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const uint8_t * restrict q4 = x[i].ql; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; // handle the q6_k -32 offset separately using bsums const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums); const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1); const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales); const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8)); const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5); const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5); __m128i sumi_0 = _mm_setzero_si128(); __m128i sumi_1 = _mm_setzero_si128(); int is = 0; for (int j = 0; j < QK_K/128; ++j) { const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2); const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2); const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48)); const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48)); const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2); const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2); const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0); const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1); const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2); const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3); const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4); const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5); const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6); const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7); const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); is += 4; p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1); p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3); p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5); p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7); sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); } sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0); sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1); const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc); } *s = hsum_float_8(acc); #else int8_t aux8[QK_K]; int16_t aux16[8]; float sums [8]; int32_t aux32[8]; memset(sums, 0, 8*sizeof(float)); float sumf = 0; for (int i = 0; i < nb; ++i) { const uint8_t * restrict q4 = x[i].ql; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; memset(aux32, 0, 8*sizeof(int32_t)); int8_t * restrict a = aux8; for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; } a += 128; q4 += 64; qh += 32; } a = aux8; int is = 0; for (int j = 0; j < QK_K/16; ++j) { int scale = x[i].scales[is++]; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; q8 += 8; a += 8; } const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; } for (int l = 0; l < 8; ++l) sumf += sums[l]; *s = sumf; #endif } ``` You must start your code with `#elif defined(__wasm_simd128__)` To think about it, you need to take into account both the refenrence code from ARM NEON and AVX implementation. -
ngxson created this gist
Jan 26, 2025 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,245 @@ Your task is to convert a given C++ ARM NEON SIMD to WASM SIMD. Here is an example of another function: ```cpp void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; int ib = 0; float sumf = 0; assert(n % qk == 0); assert(qk == QK5_0); assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); const block_q5_0 * restrict x = vx; const block_q8_0 * restrict y = vy; #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); uint32_t qh0; uint32_t qh1; uint64_t tmp0[4]; uint64_t tmp1[4]; for (; ib + 1 < nb; ib += 2) { const block_q5_0 * restrict x0 = &x[ib]; const block_q5_0 * restrict x1 = &x[ib + 1]; const block_q8_0 * restrict y0 = &y[ib]; const block_q8_0 * restrict y1 = &y[ib + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); // extract the 5th bit via lookup table ((!b) << 4) memcpy(&qh0, x0->qh, sizeof(qh0)); memcpy(&qh1, x1->qh, sizeof(qh1)); tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; tmp0[3] = table_b2b_1[(qh0 >> 24) ]; tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; tmp1[3] = table_b2b_1[(qh1 >> 24) ]; const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); // 4-bit -> 8-bit int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__wasm_simd128__) v128_t sumv = wasm_f32x4_splat(0.0f); uint32_t qh; uint64_t tmp[4]; // TODO: check if unrolling this is better for (; ib < nb; ++ib) { const block_q5_0 * restrict x0 = &x[ib]; const block_q8_0 * restrict y0 = &y[ib]; const v128_t m4b = wasm_i8x16_splat(0x0F); // extract the 5th bit memcpy(&qh, x0->qh, sizeof(qh)); tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; tmp[3] = table_b2b_1[(qh >> 24) ]; const v128_t qhl = wasm_v128_load(tmp + 0); const v128_t qhh = wasm_v128_load(tmp + 2); const v128_t v0 = wasm_v128_load(x0->qs); // 4-bit -> 8-bit const v128_t v0l = wasm_v128_and (v0, m4b); const v128_t v0h = wasm_u8x16_shr(v0, 4); // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); // load y const v128_t v1l = wasm_v128_load(y0->qs); const v128_t v1h = wasm_v128_load(y0->qs + 16); // int8x16 -> int16x8 const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); // dot product sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( wasm_i32x4_add( wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), wasm_i32x4_dot_i16x8(v0lfh, v1lh)), wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); } sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #endif for (; ib < nb; ++ib) { uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); int sumi0 = 0; int sumi1 = 0; for (int j = 0; j < qk/2; ++j) { const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); sumi0 += (x0 * y[ib].qs[j]); sumi1 += (x1 * y[ib].qs[j + qk/2]); } int sumi = sumi0 + sumi1; sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; } *s = sumf; } ```` Here is the function that you need to convert: ```cpp void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; assert(n % qk == 0); assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); const block_q8_0 * restrict x = vx; const block_q8_0 * restrict y = vy; int ib = 0; float sumf = 0; #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); for (; ib + 1 < nb; ib += 2) { const block_q8_0 * restrict x0 = &x[ib + 0]; const block_q8_0 * restrict x1 = &x[ib + 1]; const block_q8_0 * restrict y0 = &y[ib + 0]; const block_q8_0 * restrict y1 = &y[ib + 1]; const int8x16_t x0_0 = vld1q_s8(x0->qs); const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); const int8x16_t x1_0 = vld1q_s8(x1->qs); const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); // load y const int8x16_t y0_0 = vld1q_s8(y0->qs); const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); const int8x16_t y1_0 = vld1q_s8(y1->qs); const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #endif for (; ib < nb; ++ib) { int sumi = 0; for (int j = 0; j < qk; j++) { sumi += x[ib].qs[j]*y[ib].qs[j]; } sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); } *s = sumf; } ``` You must start your code with `#elif defined(__wasm_simd128__)`