Last active
October 29, 2025 12:55
-
-
Save aseyboldt/456f91f5d3bd9ef8bddbbf17d4d1b2ad to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "4b503fb8-3649-46f3-9265-cb52f2f232ad", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "\n", | |
| "import numpy as np\n", | |
| "import scipy.linalg.lapack\n", | |
| "import numba\n", | |
| "from numba.core import cgutils\n", | |
| "import pytensor.link.numba.dispatch.linalg._LAPACK\n", | |
| "import ctypes\n", | |
| "import scipy.special.cython_special\n", | |
| "from llvmlite import ir\n", | |
| "\n", | |
| "def _get_func_ptr_inner():\n", | |
| " func = scipy.special.cython_special.psi\n", | |
| " cfunc = pytensor.link.numba.dispatch.cython_support.wrap_cython_function(func, np.float64, [np.float64])\n", | |
| " return cfunc._func_ptr\n", | |
| "\n", | |
| "@numba.njit\n", | |
| "def get_func_ptr():\n", | |
| " with numba.objmode(ptr=numba.types.c_intp):\n", | |
| " ptr = _get_func_ptr_inner()\n", | |
| " return ptr\n", | |
| "\n", | |
| "@numba.extending.intrinsic(prefer_literal=True)\n", | |
| "def _call_cached_ptr(typingctx, get_ptr_func, func_type_ref):\n", | |
| " # Typing half of the intrinsic. It yields a first-class function value of type\n", | |
| " # `func_type`. The address comes from calling the jitted `get_ptr_func` and is\n", | |
| " # cached in a private LLVM global, so it is computed at run time, not baked in.\n", | |
| " # `func_type_ref` is used via .instance_type, so it is expected to be a TypeRef.\n", | |
| " func_type = func_type_ref.instance_type\n", | |
| "\n", | |
| " def codegen(context, builder, signature, args):\n", | |
| " # Lowering half: emit IR that loads the cached pointer, computes and\n", | |
| " # publishes it if still null, then wraps it in a function value.\n", | |
| "\n", | |
| " ptr_ty = ir.PointerType(ir.IntType(8))\n", | |
| " null = ptr_ty(None)\n", | |
| "\n", | |
| " # Align to cache line\n", | |
| " align = 64\n", | |
| "\n", | |
| " mod = builder.module\n", | |
| " # A fresh global is added to the module each time this codegen runs. It starts\n", | |
| " # as null instead of holding a concrete address, presumably so cache=True\n", | |
| " # object code stays valid in a new process (NOTE(review): confirm intent).\n", | |
| " var = cgutils.add_global_variable(mod, ptr_ty, \"_ptr_cache\")\n", | |
| " # The atomic load/store below also pass `align`; this keeps the global's\n", | |
| " # declared alignment consistent with them.\n", | |
| " var.align = align\n", | |
| " #var.linkage = \"linkonce\"\n", | |
| " var.linkage = \"private\"\n", | |
| " var.initializer = null\n", | |
| "\n", | |
| " # Fast path: acquire-load the cached pointer (pairs with the release store\n", | |
| " # below). Copy it into a stack slot so both paths can write the final value.\n", | |
| " var_val = builder.load_atomic(var, \"acquire\", align)\n", | |
| " result_ptr = cgutils.alloca_once_value(builder, var_val)\n", | |
| "\n", | |
| " # Slow path, taken while the cache is still null: call get_ptr_func and\n", | |
| " # publish the result. There is no compare-and-swap, so concurrent first\n", | |
| " # callers may each compute and store. That is harmless only if get_ptr_func\n", | |
| " # always returns the same address.\n", | |
| " with builder.if_then(builder.icmp_signed(\"==\", var_val, null), likely=False):\n", | |
| " sig = typingctx.resolve_function_type(get_ptr_func, [], {})\n", | |
| " func = context.get_function(get_ptr_func, sig)\n", | |
| " new_ptr = func(builder, [])\n", | |
| " # get_ptr_func yields an integer address; convert it to i8*.\n", | |
| " new_ptr = builder.inttoptr(new_ptr, ptr_ty)\n", | |
| " builder.store_atomic(new_ptr, var, \"release\", align)\n", | |
| " builder.store(new_ptr, result_ptr)\n", | |
| "\n", | |
| " # Build the first-class function value from whichever pointer was selected.\n", | |
| " sfunc = cgutils.create_struct_proxy(func_type)(context, builder)\n", | |
| " sfunc.c_addr = builder.load(result_ptr)\n", | |
| " return sfunc._getvalue()\n", | |
| "\n", | |
| " # Signature: (get_ptr_func, func_type_ref) -> func_type. Nothing here checks\n", | |
| " # that the cached address really has signature func_type; callers must\n", | |
| " # make them agree.\n", | |
| " sig = func_type(get_ptr_func, func_type_ref)\n", | |
| " return sig, codegen\n", | |
| "\n", | |
| "func_type = numba.types.FunctionType(numba.float64(numba.float64, numba.int64))\n", | |
| "\n", | |
| "@numba.njit(cache=True)\n", | |
| "def foo(x, y):\n", | |
| " func = _call_cached_ptr(get_func_ptr, func_type)\n", | |
| " return func(x, y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "612b9880-0d44-42b1-ab51-318f848947bd", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "186 ns ± 2.12 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit foo(0.1, 1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "50e46b69-9b5a-413b-bc4e-ee91be10492b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "-10.423754940411076" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "foo(0.1, 1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "d1e4552a-3566-4cd5-9f41-04a4a44bd96c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python (Pixi)", | |
| "language": "python", | |
| "name": "pixi-kernel-python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.13.9" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment