Created
May 20, 2026 21:59
-
-
Save Joshuaalbert/1eb35a7ba0070ffe1b4308e231368758 to your computer and use it in GitHub Desktop.
A reference target for matching the ducc0 wgridder.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import itertools | |
| from typing import Union | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| import pytest | |
| from ducc0 import wgridder | |
# To install reqs: `pip install --no-binary=ducc0 ducc0 jax numpy pytest`

# Array-like inputs accepted by the public functions: jax arrays, numpy arrays or Python scalars.
ComplexArray = Union[jax.Array, np.ndarray, complex]
FloatArray = Union[jax.Array, np.ndarray, float]
IntArray = Union[jax.Array, np.ndarray, int]

# Image (real) and visibility (complex) precisions.
img_dtype = np.float64
vis_dtype = np.complex128
def vis_to_image(uvw: FloatArray, freqs: FloatArray, vis: ComplexArray,
                 pixsize_m: float,
                 pixsize_l: float,
                 center_m: float,
                 center_l: float,
                 npix_m: int, npix_l: int, wgt: FloatArray | None = None,
                 mask: FloatArray | None = None, epsilon: float = 1e-6,
                 verbosity: int = 0,
                 double_precision_accumulation: bool = False, scale_by_n: bool = True,
                 normalise: bool = True,
                 output_buffer: np.ndarray | None = None,
                 num_threads: int = 1) -> FloatArray:
    """
    Compute the image from the visibilities.

    Args:
        uvw: [num_rows, 3] array of uvw coordinates. Not modified; a float64 copy is used internally.
        freqs: [num_freqs] array of frequencies.
        vis: [num_rows, num_freqs] array of visibilities.
        wgt: [num_rows, num_freqs] array of weights, multiplied with input visibilities.
        mask: [num_rows, num_freqs] array of mask, only image vis[mask!=0]
        npix_m: number of pixels in m direction.
        npix_l: number of pixels in l direction.
        pixsize_m: scalar, pixel size in m direction.
        pixsize_l: scalar, pixel size in l direction.
        center_m: scalar, m at center of image.
        center_l: scalar, l at center of image.
        epsilon: scalar, gridding accuracy
        verbosity: verbosity level, 0, 1.
        double_precision_accumulation: whether to use double precision for accumulation, which reduces
            numerical errors.
        scale_by_n: whether to scale the image by n(l,m) = sqrt(1 - l^2 - m^2) (clipped at 0).
        normalise: whether to normalise the image by the zero-term of the DFT, i.e. divide by the sum of
            the weights of the unmasked visibilities (their count when ``wgt`` is None).
        output_buffer: optional [npix_l, npix_m] array that receives the dirty image (JY/PIXEL); any
            n-scaling and normalisation are then applied to it in place. Should be same precision as vis.
        num_threads: number of threads to use.

    Returns:
        [npix_l, npix_m] array of image in JY/BEAM, as ``img_dtype``.

    Raises:
        ValueError: if uvw, freqs or vis contain non-finite values, or if ``normalise`` is True and the
            DFT zero-term (total unmasked weight) is zero.
    """
    # np.array always yields a fresh, writable numpy copy. ``uvw.copy()`` would stay a jax array for
    # jax input, and jax arrays are immutable, so the in-place flip below would fail.
    uvw = np.array(uvw, dtype=np.float64)
    # This corrects for https://github.com/mreineck/ducc/issues/34, which is a bug in ducc0 where the
    # w coordinate is flipped.
    uvw[:, 2] *= -1
    vis = np.asarray(vis)

    # Explicit checks rather than `assert`, so validation survives `python -O`.
    if not np.all(np.isfinite(uvw)):
        raise ValueError("uvw contains non-finite values")
    if not np.all(np.isfinite(freqs)):
        raise ValueError("freqs contains non-finite values")
    if not np.all(np.isfinite(vis)):  # checks both real and imaginary parts
        raise ValueError("vis contains non-finite values")

    zero_term = None
    if normalise:
        # The adjoint normalising factor is the reciprocal of the DFT zero-term, sum_{u,v,nu} S(u,v,nu).
        # The sampling function S is each unmasked visibility's weight (1 when no weights are given).
        # Masked-out samples must not be counted. This is computed before gridding so that an empty
        # sampling fails fast.
        if mask is not None:
            sampled = np.asarray(mask) != 0
            if wgt is not None:
                zero_term = np.sum(np.asarray(wgt) * sampled)
            else:
                zero_term = np.count_nonzero(sampled)
        elif wgt is not None:
            zero_term = np.sum(np.asarray(wgt))
        else:
            zero_term = vis.size
        if zero_term == 0:
            raise ValueError("Cannot normalise: the DFT zero-term (sum of unmasked weights) is zero.")

    image = _vis2dirty(
        uvw=uvw,
        freqs=freqs,
        vis=vis,
        wgt=wgt,
        mask=mask,
        npix_m=npix_m,
        npix_l=npix_l,
        pixsize_m=pixsize_m,
        pixsize_l=pixsize_l,
        center_m=center_m,
        center_l=center_l,
        epsilon=epsilon,
        do_wgridding=True,
        flip_v=False,
        divide_by_n=False,
        sigma_min=1.1,
        sigma_max=2.6,
        double_precision_accumulation=double_precision_accumulation,
        verbosity=verbosity,
        output_buffer=output_buffer,
        num_threads=num_threads
    )

    if scale_by_n:
        # l/m coordinate of each pixel; pixel npix//2 sits exactly at the image centre.
        l_coords = (-0.5 * npix_l + np.arange(npix_l)) * pixsize_l + center_l
        m_coords = (-0.5 * npix_m + np.arange(npix_m)) * pixsize_m + center_m
        # Broadcast [npix_l, 1] against [1, npix_m] rather than materialising two meshgrids.
        n = np.sqrt(np.maximum(0., 1. - (np.square(l_coords)[:, None] + np.square(m_coords)[None, :])))
        image *= n

    if zero_term is not None:
        image *= 1.0 / float(zero_term)

    return np.asarray(image, img_dtype)
def image_to_vis(uvw: FloatArray, freqs: FloatArray, dirty: FloatArray,
                 pixsize_m: float,
                 pixsize_l: float,
                 center_m: float,
                 center_l: float,
                 mask: FloatArray | None = None,
                 epsilon: float = 1e-6,
                 num_threads: int = 1,
                 verbosity: int = 0,
                 output_buffer: np.ndarray | None = None):
    """
    Compute the visibilities from the dirty image.

    Args:
        uvw: [num_rows, 3] array of uvw coordinates. Not modified; a float64 copy is used internally.
        freqs: [num_freqs] array of frequencies.
        dirty: [num_l, num_m] array of dirty image, in units of JY/PIXEL.
        pixsize_m: scalar, pixel size in m direction.
        pixsize_l: scalar, pixel size in l direction.
        center_m: scalar, m at center of image.
        center_l: scalar, l at center of image.
        mask: [num_rows, num_freqs] array of mask, only predict where mask!=0.
        epsilon: scalar, gridding accuracy
        num_threads: number of threads to use.
        verbosity: verbosity level, 0, 1.
        output_buffer: optional [num_rows, num_freqs] complex array, forwarded to ``_dirty2vis``.

    Returns:
        [num_rows, num_freqs] array of visibilities.
    """
    # Pixels should be in Jy/pixel.
    # np.array always yields a fresh, writable numpy copy. ``uvw.copy()`` would stay a jax array for
    # jax input, and jax arrays are immutable, so the in-place flip below would fail.
    uvw = np.array(uvw, dtype=np.float64)
    # This corrects for https://github.com/mreineck/ducc/issues/34 (w coordinate flipped in ducc0).
    uvw[:, 2] *= -1
    return _dirty2vis(
        uvw=uvw,
        freqs=freqs,
        dirty=dirty,
        wgt=None,
        mask=mask,
        pixsize_m=pixsize_m,
        pixsize_l=pixsize_l,
        center_m=center_m,
        center_l=center_l,
        epsilon=epsilon,
        do_wgridding=True,
        flip_v=False,
        divide_by_n=True,
        sigma_min=1.1,
        sigma_max=2.6,
        verbosity=verbosity,
        num_threads=num_threads,
        output_buffer=output_buffer
    )
def _dirty2vis(uvw: FloatArray, freqs: FloatArray,
               dirty: FloatArray, wgt: FloatArray | None,
               mask: FloatArray | None,
               pixsize_m: float, pixsize_l: float,
               center_m: float, center_l: float,
               epsilon: float, do_wgridding: bool,
               flip_v: bool, divide_by_n: bool,
               sigma_min: float, sigma_max: float,
               verbosity: int,
               output_buffer: np.ndarray | None = None,
               num_threads: int = 1):
    """
    Compute the visibilities from the dirty image with ``ducc0.wgridder.dirty2vis``.

    The l axis is passed to ducc0 as x and the m axis as y.

    Args:
        uvw: [num_rows, 3] array of uvw coordinates.
        freqs: [num_freqs] array of frequencies.
        dirty: [num_l, num_m] float32 or float64 array of dirty image, in units of JY/PIXEL.
        wgt: [num_rows, num_freqs] array of weights, multiplied with output visibilities.
        mask: [num_rows, num_freqs] array of mask, only predict where mask!=0.
        pixsize_m: scalar, pixel size in m direction (ducc0 y).
        pixsize_l: scalar, pixel size in l direction (ducc0 x).
        center_m: scalar, center of image in m direction (ducc0 y).
        center_l: scalar, center of image in l direction (ducc0 x).
        epsilon: scalar, gridding accuracy.
        do_wgridding: scalar, whether to do w-gridding.
        flip_v: scalar, whether to flip the v axis.
        divide_by_n: whether to divide by n.
        sigma_min: scalar, minimum sigma for gridding.
        sigma_max: scalar, maximum sigma for gridding.
        verbosity: verbosity level, 0, 1.
        output_buffer: optional [num_rows, num_freqs] Fortran-contiguous complex array that receives the
            visibilities. Its precision must match ``dirty`` (complex64 for float32, complex128 for float64).
        num_threads: number of threads to use.

    Returns:
        [num_rows, num_freqs] array of visibilities; this is ``output_buffer`` itself when one is given.

    Raises:
        ValueError: if ``dirty`` is not float32/float64, or ``output_buffer`` has the wrong shape,
            dtype or memory layout.
    """
    if output_buffer is not None:
        if np.shape(output_buffer) != (np.shape(uvw)[0], np.shape(freqs)[0]):
            raise ValueError(
                f"Expected output_buffer to have shape {(np.shape(uvw)[0], np.shape(freqs)[0])}, got {np.shape(output_buffer)}"
            )
        if not np.issubdtype(np.result_type(output_buffer), jnp.complexfloating):
            raise ValueError(
                f"Expected output_buffer to have dtype {jnp.complexfloating}, got {np.result_type(output_buffer)}"
            )
        if not output_buffer.flags['F_CONTIGUOUS']:
            raise ValueError("Expected output_buffer to be Fortran contiguous.")

    uvw = np.asarray(uvw, order='C', dtype=np.float64)  # [num_rows, 3]
    freqs = np.asarray(freqs, order='C', dtype=np.float64)  # [num_freqs]
    dirty = np.asarray(dirty, order='C')  # [num_l, num_m]
    num_rows, _ = np.shape(uvw)
    num_freq, = np.shape(freqs)

    # ducc0 requires the visibility precision to match the image precision.
    if dirty.dtype == np.float32:
        vis_precision = np.complex64
    elif dirty.dtype == np.float64:
        vis_precision = np.complex128
    else:
        raise ValueError(f"Expected dirty to be float32 or float64, got {dirty.dtype}")

    if wgt is not None:
        wgt = np.asarray(wgt, order='C').astype(dirty.dtype)  # [num_rows, num_freqs]
    if mask is not None:
        # Compare against 0 before narrowing to uint8; a direct astype would truncate/wrap nonzero
        # values such as 0.5 or 256 to 0 and silently drop those visibilities.
        mask = np.ascontiguousarray(np.asarray(mask) != 0, dtype=np.uint8)  # [num_rows, num_freqs]

    if output_buffer is None:
        output_vis = np.zeros((num_rows, num_freq), order='C', dtype=vis_precision)
    else:
        if output_buffer.dtype != vis_precision:
            raise ValueError(
                f"Expected output_buffer to have dtype {np.dtype(vis_precision)} to match dirty "
                f"({dirty.dtype}), got {output_buffer.dtype}"
            )
        # Zero first so masked-out entries read as 0, exactly as in the freshly allocated path.
        output_buffer[...] = 0
        output_vis = output_buffer

    wgridder.dirty2vis(
        uvw=uvw,
        freq=freqs,
        dirty=dirty,
        wgt=wgt,
        mask=mask,
        pixsize_x=pixsize_l,
        pixsize_y=pixsize_m,
        center_x=center_l,
        center_y=center_m,
        epsilon=float(epsilon),
        do_wgridding=bool(do_wgridding),
        flip_v=bool(flip_v),
        divide_by_n=bool(divide_by_n),
        sigma_min=float(sigma_min),
        sigma_max=float(sigma_max),
        nthreads=num_threads,
        verbosity=int(verbosity),
        vis=output_vis
    )
    return output_vis
def _vis2dirty(
        uvw: FloatArray, freqs: FloatArray,
        vis: FloatArray, wgt: FloatArray | None,
        mask: FloatArray | None,
        npix_m: IntArray, npix_l: IntArray,
        pixsize_m: float, pixsize_l: float,
        center_m: float, center_l: float,
        epsilon: float, do_wgridding: bool,
        flip_v: bool, divide_by_n: bool,
        sigma_min: float, sigma_max: float,
        verbosity: int,
        double_precision_accumulation: bool,
        output_buffer: np.ndarray | None = None,
        num_threads: int = 1
):
    """
    Compute the dirty image from the visibilities with ``ducc0.wgridder.vis2dirty``.

    The l axis is passed to ducc0 as x and the m axis as y.

    Args:
        uvw: [num_rows, 3] array of uvw coordinates.
        freqs: [num_freqs] array of frequencies.
        vis: [num_rows, num_freqs] array of visibilities.
        wgt: [num_rows, num_freqs] array of weights, multiplied with input visibilities.
        mask: [num_rows, num_freqs] array of mask, only grid where mask!=0.
        npix_m: number of pixels in m direction (ducc0 y). Must be even and >= 32.
        npix_l: number of pixels in l direction (ducc0 x). Must be even and >= 32.
        pixsize_m: scalar, pixel size in m direction (ducc0 y).
        pixsize_l: scalar, pixel size in l direction (ducc0 x).
        center_m: scalar, center of image in m direction (ducc0 y).
        center_l: scalar, center of image in l direction (ducc0 x).
        epsilon: scalar, gridding accuracy.
        do_wgridding: scalar, whether to do w-gridding.
        flip_v: scalar, whether to flip the v axis.
        divide_by_n: whether to divide by n.
        sigma_min: scalar, minimum sigma for gridding.
        sigma_max: scalar, maximum sigma for gridding.
        verbosity: verbosity level, 0, 1.
        double_precision_accumulation: whether to use double precision for accumulation, which reduces
            numerical errors for special cases.
        output_buffer: optional [npix_l, npix_m] array that receives the dirty image, in units of JY/PIXEL.
            Its dtype must be the real precision of ``vis``.
        num_threads: number of threads to use.

    Returns:
        [npix_l, npix_m] array of dirty image, in units of JY/PIXEL (``output_buffer`` itself when given).

    Raises:
        ValueError: if npix_l/npix_m are odd or below 32, or ``output_buffer`` has the wrong shape or dtype.
    """
    uvw = np.asarray(uvw, order='C', dtype=np.float64)  # [num_rows, 3]
    freqs = np.asarray(freqs, order='C', dtype=np.float64)  # [num_freqs]
    vis = np.asarray(vis, order='C')  # [num_rows, num_freqs], C order
    # The image has the real precision matching the visibilities (float32 for complex64, etc.).
    output_type = vis.real.dtype
    if wgt is not None:
        wgt = np.asarray(wgt, order='C').astype(output_type)
    if mask is not None:
        # Compare against 0 before narrowing to uint8; a direct astype would truncate/wrap nonzero
        # values such as 0.5 or 256 to 0 and silently drop those visibilities.
        mask = np.ascontiguousarray(np.asarray(mask) != 0, dtype=np.uint8)

    # Coerce (possibly numpy/jax) integer scalars to plain Python ints for ducc0 and the shape below.
    npix_m = int(npix_m)
    npix_l = int(npix_l)
    if npix_m % 2 != 0 or npix_l % 2 != 0:
        raise ValueError("npix_m and npix_l must both be even.")
    if npix_m < 32 or npix_l < 32:
        raise ValueError("npix_l and npix_m must be at least 32.")

    # Make sure the output is in JY/PIXEL
    if output_buffer is not None:
        if np.shape(output_buffer) != (npix_l, npix_m):
            raise ValueError(f"Expected output_buffer to have shape {(npix_l, npix_m)}, got {np.shape(output_buffer)}")
        if not np.issubdtype(np.result_type(output_buffer), output_type):
            raise ValueError(f"Expected output_buffer to have dtype {output_type}, got {np.result_type(output_buffer)}")
        dirty = output_buffer
    else:
        dirty = np.zeros((npix_l, npix_m), order='C', dtype=output_type)

    wgridder.vis2dirty(
        uvw=uvw,
        freq=freqs,
        vis=vis,
        wgt=wgt,
        mask=mask,
        npix_x=npix_l,
        npix_y=npix_m,
        pixsize_x=pixsize_l,
        pixsize_y=pixsize_m,
        center_x=center_l,
        center_y=center_m,
        epsilon=float(epsilon),
        do_wgridding=bool(do_wgridding),
        flip_v=bool(flip_v),
        divide_by_n=bool(divide_by_n),
        sigma_min=float(sigma_min),
        sigma_max=float(sigma_max),
        nthreads=num_threads,
        verbosity=int(verbosity),
        dirty=dirty,
        double_precision_accumulation=bool(double_precision_accumulation)
    )
    return dirty
@pytest.mark.parametrize("center_offset", [0.0, 0.1, 0.2])
def test_against_explicit(center_offset: float):
    """Check the ducc0-backed degridder and gridder against explicit direct-sum implementations."""
    # A local RandomState seeded with 42 produces exactly the same stream as np.random.seed(42) followed
    # by np.random.normal, without reseeding numpy's global RNG for other tests.
    rng = np.random.RandomState(42)
    N = 512
    num_ants = 10
    antenna1, antenna2 = np.asarray(list(itertools.combinations(range(num_ants), 2))).T
    B = len(antenna1)
    uvw = 16e3 * rng.normal(size=(B, 3))
    uvw[:, 2] *= 0.001
    freqs = 700e6 + np.arange(20) * 130e6
    pixsize = 0.5 * np.pi / 180 / 3600.  # 0.5 arcsec pixels
    l0 = center_offset
    m0 = center_offset
    dl = pixsize
    dm = pixsize
    sky_image = np.zeros((N, N), dtype=img_dtype)
    # Put some components
    sky_image[N // 2, N // 2] = 1.
    sky_image[N // 2, N // 3] = 1.
    sky_image[N // 3, N // 3] = 1.
    sky_image[N // 3, N // 4] = 1.
    sky_image[N // 4, N // 4] = 1.
    sky_image[N // 4, N // 5] = 1.

    def pixel_to_lmn(xi, yi):
        # Same pixel -> (l, m) convention as vis_to_image / explicit_gridder.
        l = l0 + (-N / 2 + xi) * dl
        m = m0 + (-N / 2 + yi) * dm
        n = np.sqrt(1. - l ** 2 - m ** 2)
        return np.asarray([l, m, n])

    x, y = np.where(sky_image)
    lmn = [pixel_to_lmn(xi, yi) for xi, yi in zip(x, y)]
    pixel_fluxes = [sky_image[xi, yi] for xi, yi in zip(x, y)]

    vis = image_to_vis(
        uvw=uvw,
        freqs=freqs,
        dirty=jnp.asarray(sky_image),
        pixsize_l=dl,
        pixsize_m=dm,
        center_l=l0,
        center_m=m0,
        epsilon=1e-6
    )
    vis_explicit = explicit_degridder(uvw, freqs, lmn, pixel_fluxes)

    # Now test the gridding
    dirty_rec = vis_to_image(uvw=uvw, freqs=freqs, vis=vis, pixsize_m=dm, pixsize_l=dl, center_m=m0, center_l=l0,
                             npix_m=N, npix_l=N, scale_by_n=False, normalise=False)
    dirty_explicit = explicit_gridder(uvw, freqs, vis, N, N, dl, dm, l0, m0)

    np.testing.assert_allclose(dirty_rec, dirty_explicit, atol=3e-4, rtol=0)
    np.testing.assert_allclose(vis.real, vis_explicit.real, atol=2e-6, rtol=0)
    np.testing.assert_allclose(vis.imag, vis_explicit.imag, atol=2e-6, rtol=0)
def explicit_degridder(uvw, freqs, lmn, pixel_fluxes):
    """
    Reference degridder (prediction) via a direct sum over point sources.

    vis[row, f] = sum_s flux_s / n_s * exp(-2j*pi*(u*l_s + v*m_s + w*(n_s - 1)) * freqs[f] / c)

    Args:
        uvw: [num_rows, 3] uvw coordinates in metres.
        freqs: [num_freqs] frequencies in Hz.
        lmn: sequence of [l, m, n] direction cosines, one per source (may be empty).
        pixel_fluxes: sequence of source fluxes, aligned with ``lmn``.

    Returns:
        [num_rows, num_freqs] complex128 array of visibilities. complex128 keeps the reference at
        least as precise as the implementation under test.
    """
    c = 299792458.  # m/s
    uvw = np.asarray(uvw, dtype=np.float64).reshape(-1, 3)
    freqs = np.asarray(freqs, dtype=np.float64).reshape(-1)
    lmn = np.asarray(lmn, dtype=np.float64).reshape(-1, 3)
    fluxes = np.asarray(pixel_fluxes).reshape(-1)
    vis = np.zeros((uvw.shape[0], freqs.shape[0]), dtype=np.complex128)
    inv_wavelength = freqs / c  # [num_freqs], hoisted out of the source loop
    for flux, (l, m, n) in zip(fluxes, lmn):
        # Geometric path difference in metres for every row: [num_rows]
        path = uvw[:, 0] * l + uvw[:, 1] * m + uvw[:, 2] * (n - 1)
        phase = -2j * np.pi * path[:, None] * inv_wavelength[None, :]  # [num_rows, num_freqs]
        vis += flux * np.exp(phase) / n
    return vis
def explicit_gridder(uvw, freqs, vis, num_l, num_m, dl, dm, center_l, center_m):
    """
    Reference gridder (adjoint) via a direct sum over visibilities.

    dirty[i, j] = sum_{row, f} Re(vis[row, f] * exp(2j*pi*(u*l_i + v*m_j + w*(n_ij - 1)) / lambda_f))

    The pixel grid is l_i = (i - num_l/2)*dl + center_l and m_j = (j - num_m/2)*dm + center_m, with
    n_ij = sqrt(1 - l_i^2 - m_j^2). No n-scaling or normalisation is applied.

    Args:
        uvw: [num_rows, 3] uvw coordinates in metres.
        freqs: [num_freqs] frequencies in Hz.
        vis: [num_rows, num_freqs] complex visibilities.
        num_l, num_m: image size in l and m.
        dl, dm: pixel sizes in l and m.
        center_l, center_m: l and m at the image centre.

    Returns:
        [num_l, num_m] array of dtype ``img_dtype``.
    """
    c = 299792458.  # m/s
    lvec = (-0.5 * num_l + np.arange(num_l)) * dl + center_l
    mvec = (-0.5 * num_m + np.arange(num_m)) * dm + center_m
    l_grid, m_grid = np.meshgrid(lvec, mvec, indexing='ij')
    # Loop-invariant: previously recomputed (an N x N array op) for every (row, freq) pair.
    n_minus_one = np.sqrt(1. - l_grid ** 2 - m_grid ** 2) - 1
    wavelengths = c / np.asarray(freqs)
    dirty = np.zeros((num_l, num_m), img_dtype)
    for row, (u, v, w) in enumerate(uvw):
        for col, wavelength in enumerate(wavelengths):
            phase = 2j * np.pi * (u * l_grid + v * m_grid + w * n_minus_one) / wavelength
            dirty += (vis[row, col] * np.exp(phase)).real
    return dirty
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment