Created
January 7, 2025 01:33
-
-
Save Axect/3f6b0e537874166eee902f9f2dd0bfe4 to your computer and use it in GitHub Desktop.
Patch for __init__.py for clifford
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _get_mult_function(mt: sparse.COO):
    """
    Get a function similar to ``lambda a, b: np.einsum('i,ijk,k->j', a, mt, b)``

    Parameters
    ----------
    mt : sparse.COO
        Sparse multiplication table. Each stored entry with coordinates
        ``(k, l, m)`` and value ``v`` contributes ``a[k] * v * b[m]`` to
        output index ``l``; the output length is ``mt.shape[1]``.

    Returns
    -------
    func : function (array_like (n_dims,), array_like (n_dims,)) -> array_like (n_dims,)
        A function that computes the appropriate multiplication. It is an
        ``_numba_utils.njit`` dispatcher, so it can be called from Python or
        from other nopython-mode code.

    Notes
    -----
    The result dtype depends on the argument *types*, so it must be chosen
    while numba compiles each specialization (formerly done with
    ``generated_jit``). Here that happens in a ``numba.extending.overload``
    typing function, where ``value``/``other_value`` are numba types, and
    ``_get_mult_function_result_type`` is evaluated there exactly as the old
    ``generated_jit`` body did. Calling that helper from inside a plain
    ``njit`` body would instead require it to be nopython-compatible.
    NOTE(review): assumes ``_get_mult_function_result_type`` accepts numba
    types, as in upstream clifford -- confirm.
    """
    # numba is already a dependency of this module via _numba_utils
    from numba.extending import overload

    # unpack for numba: these plain arrays/ints are frozen into compiled code
    dims = mt.shape[1]
    k_list, l_list, m_list = mt.coords
    mult_table_vals = mt.data

    def _mv_mult_impl(value, other_value):
        # Typing stub: only the numba overload registered below is ever run.
        raise NotImplementedError(
            "_mv_mult_impl is only callable from numba-compiled code")

    @overload(_mv_mult_impl)
    def _mv_mult_impl_overload(value, other_value):
        # Runs during compilation, once per argument-type signature.
        ret_dtype = _get_mult_function_result_type(value, other_value, mult_table_vals.dtype)
        # Cast the table once per signature instead of on every call.
        mult_table_vals_t = mult_table_vals.astype(ret_dtype)

        def mult_inner(value, other_value):
            output = np.zeros(dims, dtype=ret_dtype)
            for k, l, m, val in zip(k_list, l_list, m_list, mult_table_vals_t):
                output[l] += value[k] * val * other_value[m]
            return output

        return mult_inner

    @_numba_utils.njit
    def mv_mult(value, other_value):
        return _mv_mult_impl(value, other_value)

    return mv_mult
def _get_mult_function_runtime_sparse(mt: sparse.COO):
    """
    A variant of `_get_mult_function` that attempts to exploit runtime zeros

    The returned function avoids performing multiplications if vectors contain
    zeros: a table entry ``(k, l, m)`` is skipped when ``value[k] == 0`` or
    ``other_value[m] == 0``.

    Parameters
    ----------
    mt : sparse.COO
        Sparse multiplication table; output length is ``mt.shape[1]``.

    Returns
    -------
    func : function (array_like (n_dims,), array_like (n_dims,)) -> array_like (n_dims,)
        An ``_numba_utils.njit`` dispatcher computing the multiplication.

    Notes
    -----
    As in `_get_mult_function`, the result dtype is chosen at compile time in
    a ``numba.extending.overload`` typing function (the replacement for
    ``generated_jit``), where the arguments are numba types.
    NOTE(review): assumes ``_get_mult_function_result_type`` accepts numba
    types, as in upstream clifford -- confirm.

    TODO: determine if this actually helps.
    """
    # numba is already a dependency of this module via _numba_utils
    from numba.extending import overload

    # unpack for numba: these plain arrays/ints are frozen into compiled code
    dims = mt.shape[1]
    k_list, l_list, m_list = mt.coords
    mult_table_vals = mt.data

    def _mv_mult_impl(value, other_value):
        # Typing stub: only the numba overload registered below is ever run.
        raise NotImplementedError(
            "_mv_mult_impl is only callable from numba-compiled code")

    @overload(_mv_mult_impl)
    def _mv_mult_impl_overload(value, other_value):
        # Runs during compilation, once per argument-type signature.
        ret_dtype = _get_mult_function_result_type(value, other_value, mult_table_vals.dtype)
        # Cast the table once per signature instead of on every call.
        mult_table_vals_t = mult_table_vals.astype(ret_dtype)

        def mult_inner(value, other_value):
            output = np.zeros(dims, dtype=ret_dtype)
            for ind, k in enumerate(k_list):
                v_val = value[k]
                if v_val == 0.0:
                    continue  # whole row of contributions is zero
                ov_val = other_value[m_list[ind]]
                if ov_val == 0.0:
                    continue
                output[l_list[ind]] += v_val * mult_table_vals_t[ind] * ov_val
            return output

        return mult_inner

    @_numba_utils.njit
    def mv_mult(value, other_value):
        return _mv_mult_impl(value, other_value)

    return mv_mult
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment