Skip to content

Instantly share code, notes, and snippets.

@aseyboldt
Last active October 29, 2025 12:55
Show Gist options
  • Select an option

  • Save aseyboldt/456f91f5d3bd9ef8bddbbf17d4d1b2ad to your computer and use it in GitHub Desktop.

Select an option

Save aseyboldt/456f91f5d3bd9ef8bddbbf17d4d1b2ad to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"id": "4b503fb8-3649-46f3-9265-cb52f2f232ad",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import numpy as np\n",
"import scipy.linalg.lapack\n",
"import numba\n",
"from numba.core import cgutils\n",
"import pytensor.link.numba.dispatch.linalg._LAPACK\n",
"import ctypes\n",
"import scipy.special.cython_special\n",
"from llvmlite import ir\n",
"\n",
"def _get_func_ptr_inner():\n",
" func = scipy.special.cython_special.psi\n",
" cfunc = pytensor.link.numba.dispatch.cython_support.wrap_cython_function(func, np.float64, [np.float64])\n",
" return cfunc._func_ptr\n",
"\n",
"@numba.njit\n",
"def get_func_ptr():\n",
" with numba.objmode(ptr=numba.types.c_intp):\n",
" ptr = _get_func_ptr_inner()\n",
" return ptr\n",
"\n",
"@numba.extending.intrinsic(prefer_literal=True)\n",
"def _call_cached_ptr(typingctx, get_ptr_func, func_type_ref):\n",
" func_type = func_type_ref.instance_type\n",
"\n",
" def codegen(context, builder, signature, args):\n",
"\n",
" ptr_ty = ir.PointerType(ir.IntType(8))\n",
" null = ptr_ty(None)\n",
"\n",
" # Align to cache line\n",
" align = 64\n",
"\n",
" mod = builder.module\n",
" var = cgutils.add_global_variable(mod, ptr_ty, \"_ptr_cache\")\n",
" var.align = align\n",
" #var.linkage = \"linkonce\"\n",
" var.linkage = \"private\"\n",
" var.initializer = null\n",
"\n",
" var_val = builder.load_atomic(var, \"acquire\", align)\n",
" result_ptr = cgutils.alloca_once_value(builder, var_val)\n",
"\n",
" with builder.if_then(builder.icmp_signed(\"==\", var_val, null), likely=False):\n",
" sig = typingctx.resolve_function_type(get_ptr_func, [], {})\n",
" func = context.get_function(get_ptr_func, sig)\n",
" new_ptr = func(builder, [])\n",
" new_ptr = builder.inttoptr(new_ptr, ptr_ty)\n",
" builder.store_atomic(new_ptr, var, \"release\", align)\n",
" builder.store(new_ptr, result_ptr)\n",
"\n",
" sfunc = cgutils.create_struct_proxy(func_type)(context, builder)\n",
" sfunc.c_addr = builder.load(result_ptr)\n",
" return sfunc._getvalue()\n",
"\n",
" sig = func_type(get_ptr_func, func_type_ref)\n",
" return sig, codegen\n",
"\n",
"func_type = numba.types.FunctionType(numba.float64(numba.float64, numba.int64))\n",
"\n",
"@numba.njit(cache=True)\n",
"def foo(x, y):\n",
" func = _call_cached_ptr(get_func_ptr, func_type)\n",
" return func(x, y)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "612b9880-0d44-42b1-ab51-318f848947bd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"186 ns ± 2.12 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n"
]
}
],
"source": [
"%timeit foo(0.1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "50e46b69-9b5a-413b-bc4e-ee91be10492b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-10.423754940411076"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"foo(0.1, 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1e4552a-3566-4cd5-9f41-04a4a44bd96c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (Pixi)",
"language": "python",
"name": "pixi-kernel-python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment