Skip to content

Instantly share code, notes, and snippets.

@Joshuaalbert
Created May 20, 2026 21:59
Show Gist options
  • Select an option

  • Save Joshuaalbert/1eb35a7ba0070ffe1b4308e231368758 to your computer and use it in GitHub Desktop.

Select an option

Save Joshuaalbert/1eb35a7ba0070ffe1b4308e231368758 to your computer and use it in GitHub Desktop.
A reference target for checking the ducc0 wgridder wrappers against explicit DFT (de)gridding.
import itertools
from typing import Union
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from ducc0 import wgridder
"""
To install reqs: `pip install --no-binary=ducc0 ducc0 jax numpy pytest`
"""
ComplexArray = Union[jax.Array, np.ndarray, complex]
FloatArray = Union[jax.Array, np.ndarray, float]
IntArray = Union[jax.Array, np.ndarray, int]
img_dtype = np.float64
vis_dtype = np.complex128
def vis_to_image(uvw: FloatArray, freqs: FloatArray, vis: ComplexArray,
pixsize_m: float,
pixsize_l: float,
center_m: float,
center_l: float,
npix_m: int, npix_l: int, wgt: FloatArray | None = None,
mask: FloatArray | None = None, epsilon: float = 1e-6,
verbosity: int = 0,
double_precision_accumulation: bool = False, scale_by_n: bool = True,
normalise: bool = True,
output_buffer: np.ndarray | None = None,
num_threads: int = 1) -> FloatArray:
"""
Compute the image from the visibilities.
Args:
uvw: [num_rows, 3] array of uvw coordinates.
freqs: [num_freqs] array of frequencies.
vis: [num_rows, num_freqs] array of visibilities.
wgt: [num_rows, num_freqs] array of weights, multiplied with input visibilities.
mask: [num_rows, num_freqs] array of mask, only image vis[mask!=0]
npix_m: number of pixels in m direction.
npix_l: number of pixels in l direction.
pixsize_m: scalar, pixel size in m direction.
pixsize_l: scalar, pixel size in l direction.
center_m: scalar, m at center of image.
center_l: scalar, l at center of image.
epsilon: scalar, gridding accuracy
verbosity: verbosity level, 0, 1.
double_precision_accumulation: whether to use double precision for accumulation, which reduces numerical errors.
scale_by_n: whether to scale the image by n(l,m).
normalise: whether to normalise the image by the zero-term of the DFT.
output_buffer: optional [npix_l, npix_m] array of dirty image, in units of JY/PIXEL
Should be same precision as vis.
num_threads: number of threads to use.
Returns:
[npix_l, npix_m] array of image in JY/BEAM.
"""
# This corrects for https://github.com/mreineck/ducc/issues/34, which is a bug in ducc0 where the w coordinate is flipped.
uvw = uvw.copy()
uvw[:, 2] *= -1
assert np.all(np.isfinite(uvw)), "uvw contains non-finite values"
assert np.all(np.isfinite(freqs)), "freqs contains non-finite values"
assert np.all(np.isfinite(vis.real)), "vis contains non-finite values"
assert np.all(np.isfinite(vis.imag)), "vis contains non-finite values"
image = _vis2dirty(
uvw=uvw,
freqs=freqs,
vis=vis,
wgt=wgt,
mask=mask,
npix_m=npix_m,
npix_l=npix_l,
pixsize_m=pixsize_m,
pixsize_l=pixsize_l,
center_m=center_m,
center_l=center_l,
epsilon=epsilon,
do_wgridding=True,
flip_v=False,
divide_by_n=False,
sigma_min=1.1,
sigma_max=2.6,
double_precision_accumulation=double_precision_accumulation,
verbosity=verbosity,
output_buffer=output_buffer,
num_threads=num_threads
)
if scale_by_n:
l = (-0.5 * npix_l + np.arange(npix_l)) * pixsize_l + center_l
m = (-0.5 * npix_m + np.arange(npix_m)) * pixsize_m + center_m
l, m = np.meshgrid(l, m, indexing='ij')
n = np.sqrt(np.maximum(0., 1. - (np.square(l) + np.square(m))))
del l
del m
image *= n
if normalise:
# Adjoint normalising factor is the DFT zero-term i.e. sum_{u,v,nu} S(u,v,nu)
if wgt is not None:
if mask is not None:
adjoint_normalising_factor = np.reciprocal(np.sum(wgt * mask))
else:
adjoint_normalising_factor = np.reciprocal(np.sum(wgt))
else:
adjoint_normalising_factor = 1 / vis.size
image *= adjoint_normalising_factor
return np.asarray(image, img_dtype)
def image_to_vis(uvw: FloatArray, freqs: FloatArray, dirty: FloatArray,
                 pixsize_m: float,
                 pixsize_l: float,
                 center_m: float,
                 center_l: float,
                 mask: FloatArray | None = None,
                 epsilon: float = 1e-6,
                 num_threads: int = 1,
                 verbosity: int = 0,
                 output_buffer: np.ndarray | None = None):
    """
    Compute the visibilities from the dirty image.

    Args:
        uvw: [num_rows, 3] array of uvw coordinates (numpy or jax; never modified).
        freqs: [num_freqs] array of frequencies.
        dirty: [num_l, num_m] array of dirty image, in units of JY/PIXEL.
        pixsize_m: scalar, pixel size in m direction.
        pixsize_l: scalar, pixel size in l direction.
        center_m: scalar, m at center of image.
        center_l: scalar, l at center of image.
        mask: [num_rows, num_freqs] array of mask, only predict where mask!=0.
        epsilon: scalar, gridding accuracy.
        num_threads: number of threads to use.
        verbosity: verbosity level, 0, 1.
        output_buffer: optional [num_rows, num_freqs] array forwarded to `_dirty2vis` to hold the
            visibilities.

    Returns:
        [num_rows, num_freqs] array of visibilities.

    Raises:
        ValueError: if uvw, freqs or dirty contain non-finite values, or if `_dirty2vis` rejects
            the inputs.
    """
    # Pixels should be in Jy/pixel.
    # Make a writable float64 NumPy copy: the caller's array must not be mutated, and jax arrays are
    # immutable, so `uvw.copy()` followed by in-place negation would fail for them.
    uvw = np.array(uvw, dtype=np.float64)
    # Same w-sign correction as vis_to_image (https://github.com/mreineck/ducc/issues/34).
    uvw[:, 2] *= -1
    # Validate inputs the same way vis_to_image does, with exceptions that survive `python -O`.
    if not np.all(np.isfinite(uvw)):
        raise ValueError("uvw contains non-finite values")
    if not np.all(np.isfinite(freqs)):
        raise ValueError("freqs contains non-finite values")
    if not np.all(np.isfinite(dirty)):
        raise ValueError("dirty contains non-finite values")
    vis = _dirty2vis(
        uvw=uvw,
        freqs=freqs,
        dirty=dirty,
        wgt=None,
        mask=mask,
        pixsize_m=pixsize_m,
        pixsize_l=pixsize_l,
        center_m=center_m,
        center_l=center_l,
        epsilon=epsilon,
        do_wgridding=True,
        flip_v=False,
        divide_by_n=True,
        sigma_min=1.1,
        sigma_max=2.6,
        verbosity=verbosity,
        num_threads=num_threads,
        output_buffer=output_buffer
    )
    return vis
def _dirty2vis(uvw: FloatArray, freqs: FloatArray,
dirty: FloatArray, wgt: FloatArray | None,
mask: FloatArray | None,
pixsize_m: float, pixsize_l: float,
center_m: float, center_l: float,
epsilon: float, do_wgridding: bool,
flip_v: bool, divide_by_n: bool,
sigma_min: float, sigma_max: float,
verbosity: int,
output_buffer: np.ndarray | None = None,
num_threads: int = 1):
"""
Compute the visibilities from the dirty image.
Args:
uvw: [num_rows, 3] array of uvw coordinates.
freqs: [num_freqs[,1]] array of frequencies.
dirty: [[num_freqs], ..., num_l, num_m] array of dirty image, in units of JY/PIXEL.
wgt: [num_rows, num_freqs] array of weights, multiplied with output visibilities.
mask: [num_rows, num_freqs] array of mask, only predict where mask!=0.
pixsize_m: scalar, pixel size in x direction.
pixsize_l: scalar, pixel size in y direction.
center_m: scalar, center of image in x direction.
center_l: scalar, center of image in y direction.
epsilon: scalar, gridding kernel width.
do_wgridding: scalar, whether to do w-gridding.
flip_v: scalar, whether to flip the v axis.
divide_by_n: whether to divide by n.
sigma_min: scalar, minimum sigma for gridding.
sigma_max: scalar, maximum sigma for gridding.
verbosity: verbosity level, 0, 1.
Returns:
[num_rows, num_freqs] array of visibilities.
"""
if output_buffer is not None:
if np.shape(output_buffer) != (np.shape(uvw)[0], np.shape(freqs)[0]):
raise ValueError(
f"Expected output_buffer to have shape {(np.shape(uvw)[0], np.shape(freqs)[0])}, got {np.shape(output_buffer)}"
)
if not np.issubdtype(np.result_type(output_buffer), jnp.complexfloating):
raise ValueError(
f"Expected output_buffer to have dtype {jnp.complexfloating}, got {np.result_type(output_buffer)}"
)
order = output_buffer.flags['F_CONTIGUOUS']
if not order:
raise ValueError("Expected output_buffer to be Fortran contiguous.")
uvw = np.asarray(uvw, order='C', dtype=np.float64) # [num_rows, 3]
freqs = np.asarray(freqs, order='C', dtype=np.float64) # [num_freqs]
dirty = np.asarray(dirty, order='C') # [num_l, num_m]
num_rows, _ = np.shape(uvw)
num_freq, = np.shape(freqs)
if wgt is not None:
wgt = np.asarray(wgt, order='C').astype(dirty.dtype) # [num_rows, num_freqs]
if mask is not None:
mask = np.asarray(mask, order='C').astype(np.uint8) # [num_rows, num_freqs]
if dirty.dtype == jnp.float32:
output_vis = np.zeros((num_rows, num_freq), order='C', dtype=np.complex64)
elif dirty.dtype == jnp.float64:
output_vis = np.zeros((num_rows, num_freq), order='C', dtype=np.complex128)
else:
raise ValueError(f"Expected dirty to be float32 or float64, got {dirty.dtype}")
wgridder.dirty2vis(
uvw=uvw,
freq=freqs,
dirty=dirty,
wgt=wgt,
mask=mask,
pixsize_x=pixsize_l,
pixsize_y=pixsize_m,
center_x=center_l,
center_y=center_m,
epsilon=float(epsilon),
do_wgridding=bool(do_wgridding),
flip_v=bool(flip_v),
divide_by_n=bool(divide_by_n),
sigma_min=float(sigma_min),
sigma_max=float(sigma_max),
nthreads=num_threads,
verbosity=int(verbosity),
vis=output_vis
)
return output_vis
def _vis2dirty(
uvw: FloatArray, freqs: FloatArray,
vis: FloatArray, wgt: FloatArray | None,
mask: FloatArray | None,
npix_m: IntArray, npix_l: IntArray,
pixsize_m: float, pixsize_l: float,
center_m: float, center_l: float,
epsilon: float, do_wgridding: bool,
flip_v: bool, divide_by_n: bool,
sigma_min: float, sigma_max: float,
verbosity: int,
double_precision_accumulation: bool,
output_buffer: np.ndarray | None = None,
num_threads: int = 1
):
"""
Compute the dirty image from the visibilities.
Args:
uvw: [num_rows, 3] array of uvw coordinates.
freqs: [num_freqs] array of frequencies.
vis: [num_rows, num_freqs] array of visibilities.
wgt: [num_rows, num_freqs] array of weights, multiplied with input visibilities.
mask: [num_rows, num_freqs] array of mask, only predict where mask!=0.
npix_m: number of pixels in y direction.
npix_l: number of pixels in x direction.
pixsize_m: scalar, pixel size in y direction.
pixsize_l: scalar, pixel size in x direction.
center_m: scalar, center of image in x direction.
center_l: scalar, center of image in y direction.
epsilon: scalar, gridding kernel width.
do_wgridding: scalar, whether to do w-gridding.
flip_v: scalar, whether to flip the v axis.
divide_by_n: whether to divide by n.
sigma_min: scalar, minimum sigma for gridding.
sigma_max: scalar, maximum sigma for gridding.
verbosity: verbosity level, 0, 1.
double_precision_accumulation: whether to use double precision for accumulation, which reduces numerical
errors for special cases.
output_buffer: optional [npix_l, npix_m] array of dirty image, in units of JY/PIXEL
Returns:
[npix_l, npix_m] array of dirty image, in units of JY/PIXEL.
"""
uvw = np.asarray(uvw, order='C', dtype=np.float64) # [num_rows, 3]
freqs = np.asarray(freqs, order='C', dtype=np.float64) # [num_freqs]
vis = np.asarray(vis, order='C') # Fortran order for better cache locality
output_type = vis.real.dtype
if wgt is not None:
wgt = np.asarray(wgt, order='C').astype(output_type)
if mask is not None:
mask = np.asarray(mask, order='C').astype(np.uint8)
if npix_m % 2 != 0 or npix_l % 2 != 0:
raise ValueError("npix_m and npix_l must both be even.")
if npix_m < 32 or npix_l < 32:
raise ValueError("npix_l and npix_m must be at least 32.")
# Make sure the output is in JY/PIXEL
if output_buffer is not None:
if np.shape(output_buffer) != (npix_l, npix_m):
raise ValueError(f"Expected output_buffer to have shape {(npix_l, npix_m)}, got {np.shape(output_buffer)}")
if not np.issubdtype(np.result_type(output_buffer), output_type):
raise ValueError(f"Expected output_buffer to have dtype {output_type}, got {np.result_type(output_buffer)}")
dirty = output_buffer
else:
dirty = np.zeros((npix_l, npix_m), order='C', dtype=output_type)
_ = wgridder.vis2dirty(
uvw=uvw,
freq=freqs,
vis=vis,
wgt=wgt,
mask=mask,
npix_x=npix_l,
npix_y=npix_m,
pixsize_x=pixsize_l,
pixsize_y=pixsize_m,
center_x=center_l,
center_y=center_m,
epsilon=epsilon,
do_wgridding=do_wgridding,
flip_v=flip_v,
divide_by_n=divide_by_n,
sigma_min=sigma_min,
sigma_max=sigma_max,
nthreads=num_threads,
verbosity=verbosity,
dirty=dirty,
double_precision_accumulation=double_precision_accumulation
)
return dirty
@pytest.mark.parametrize("center_offset", [0.0, 0.1, 0.2])
def test_against_explicit(center_offset: float):
np.random.seed(42)
N = 512
num_ants = 10
antenna1, antenna2 = np.asarray(list(itertools.combinations(range(num_ants), 2))).T
B = len(antenna1)
uvw = 16e3 * np.random.normal(size=(B, 3))
uvw[:, 2] *= 0.001
freqs = 700e6 + np.arange(20) * 130e6
pixsize = 0.5 * np.pi / 180 / 3600. # 1 arcsec ~ 4 pixels / beam, so we'll avoid aliasing
l0 = center_offset
m0 = center_offset
dl = pixsize
dm = pixsize
sky_image = np.zeros((N, N), dtype=img_dtype)
# Put some components
sky_image[N // 2, N // 2] = 1.
sky_image[N // 2, N // 3] = 1.
sky_image[N // 3, N // 3] = 1.
sky_image[N // 3, N // 4] = 1.
sky_image[N // 4, N // 4] = 1.
sky_image[N // 4, N // 5] = 1.
def pixel_to_lmn(xi, yi):
l = l0 + (-N / 2 + xi) * dl
m = m0 + (-N / 2 + yi) * dm
n = np.sqrt(1. - l ** 2 - m ** 2)
return np.asarray([l, m, n])
x, y = np.where(sky_image)
lmn = [pixel_to_lmn(xi, yi) for xi, yi in zip(x, y)]
pixel_fluxes = [sky_image[xi, yi] for xi, yi in zip(x, y)]
vis = image_to_vis(
uvw=uvw,
freqs=freqs,
dirty=jnp.asarray(sky_image),
pixsize_l=dl,
pixsize_m=dm,
center_l=l0,
center_m=m0,
epsilon=1e-6
)
vis_explicit = explicit_degridder(uvw, freqs, lmn, pixel_fluxes)
# Now test the gridding
dirty_rec = vis_to_image(uvw=uvw, freqs=freqs, vis=vis, pixsize_m=dm, pixsize_l=dl, center_m=m0, center_l=l0,
npix_m=N, npix_l=N, scale_by_n=False, normalise=False)
dirty_explicit = explicit_gridder(uvw, freqs, vis, N, N, dl, dm, l0, m0)
np.testing.assert_allclose(dirty_rec, dirty_explicit, atol=3e-4, rtol=0)
np.testing.assert_allclose(vis.real, vis_explicit.real, atol=2e-6, rtol=0)
np.testing.assert_allclose(vis.imag, vis_explicit.imag, atol=2e-6, rtol=0)
def explicit_degridder(uvw, freqs, lmn, pixel_fluxes):
    """
    Predict visibilities by brute-force evaluation of the DFT over point components.

    Args:
        uvw: [num_rows, 3] uvw coordinates, in metres (wavelengths are c / freq with c in m/s).
        freqs: [num_freqs] frequencies.
        lmn: sequence of (l, m, n) direction cosines, one per component.
        pixel_fluxes: sequence of component fluxes, aligned with `lmn`.

    Returns:
        [num_rows, num_freqs] complex64 array whose entries are
        sum_k flux_k * exp(-2 pi i (u l_k + v m_k + w (n_k - 1)) / wavelength) / n_k.
    """
    speed_of_light = 299792458.  # m/s
    out = np.zeros((len(uvw), len(freqs)), dtype=np.complex64)
    for row_idx, (u, v, w) in enumerate(uvw):
        for chan_idx, freq in enumerate(freqs):
            for flux, (l_k, m_k, n_k) in zip(pixel_fluxes, lmn):
                wavelength = speed_of_light / freq
                path = u * l_k + v * m_k + w * (n_k - 1)
                out[row_idx, chan_idx] += flux * np.exp(-2j * np.pi * path / wavelength) / n_k
    return out
def explicit_gridder(uvw, freqs, vis, num_l, num_m, dl, dm, center_l, center_m):
    """
    Image visibilities by brute-force evaluation of the adjoint DFT on the pixel grid.

    Pixel (i, j) sits at l = (i - num_l / 2) * dl + center_l and m = (j - num_m / 2) * dm + center_m.
    Each visibility contributes Re(vis * exp(+2 pi i (u l + v m + w (n - 1)) / wavelength)); no
    weighting, n-scaling or normalisation is applied. Pixels outside the unit circle get NaN n.

    Returns:
        [num_l, num_m] array of dtype `img_dtype`.
    """
    speed_of_light = 299792458.  # m/s

    def pixel_axis(count, step, centre):
        # Coordinates of `count` pixels spaced by `step`, with pixel count // 2 at `centre`.
        return (-0.5 * count + np.arange(count)) * step + centre

    l_grid, m_grid = np.meshgrid(pixel_axis(num_l, dl, center_l),
                                 pixel_axis(num_m, dm, center_m),
                                 indexing='ij')
    n_grid = np.sqrt(1. - l_grid ** 2 - m_grid ** 2)
    image = np.zeros((num_l, num_m), img_dtype)
    for row_idx, (u, v, w) in enumerate(uvw):
        for chan_idx, freq in enumerate(freqs):
            wavelength = speed_of_light / freq
            phase = 2j * np.pi * (u * l_grid + v * m_grid + w * (n_grid - 1)) / wavelength
            image += (vis[row_idx, chan_idx] * np.exp(phase)).real
    return image
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment