Last active
May 2, 2025 21:22
-
-
Save ctrueden/f76fe7f75bc6686ea07151ed625a4624 to your computer and use it in GitHub Desktop.
Image reconstruction from sparse sampling
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Steps to test this script: | |
# | |
# mamba create -n sparse -c conda-forge numpy scipy scikit-image cvxpy pyimagej
# mamba activate sparse | |
# python sparse.py | |
import logging | |
import time | |
import cvxpy as cp | |
import numpy as np | |
from scipy.fftpack import dct, idct | |
from skimage.data import astronaut | |
from skimage.io import imsave | |
from skimage.transform import resize | |
# Module-level logger used for all progress messages in this script.
_log = logging.getLogger(name=__name__)
# 2D Discrete Cosine Transform (DCT) and inverse
def dct2(img):
    """Return the orthonormal type-II DCT of *img*, first along axis 0, then along the last axis."""
    first_pass = dct(img, axis=0, norm='ortho')
    return dct(first_pass, axis=-1, norm='ortho')
def idct2(coeffs):
    """Invert ``dct2``: orthonormal inverse DCT along axis 0, then along the last axis."""
    first_pass = idct(coeffs, axis=0, norm='ortho')
    return idct(first_pass, axis=-1, norm='ortho')
def process_image_fast(image, iterations=100, *, sampling_ratio=0.2,
                       lambda_param=0.1, seed=None):
    """Recover *image* from simulated compressed-sensing measurements using ISTA.

    The image is moved to the 2-D DCT domain, ``m = int(n * sampling_ratio)``
    random Gaussian measurements ``y = A @ x_true`` of the flattened
    coefficients are simulated, and the coefficients are re-estimated by
    running ISTA on ``0.5 * ||A x - y||^2 + lambda_param * ||x||_1``.

    Parameters
    ----------
    image : array
        Input image; passed to ``dct2`` and reshaped back to its shape.
    iterations : int
        Number of ISTA iterations.
    sampling_ratio : float
        Number of measurements as a fraction of the number of coefficients.
    lambda_param : float
        L1 regularization weight; larger values give sparser solutions.
    seed : int or None
        Seed for the measurement matrix and the power-iteration start vector.
        ``None`` (default) draws from NumPy's global random state, so
        separate calls use different matrices and are not directly comparable.

    Returns
    -------
    ndarray
        Reconstructed image (floating point, not clipped), same shape as *image*.
    """
    def soft_threshold(x, threshold):
        # Proximal operator of threshold * ||x||_1.
        return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)

    rng = np.random if seed is None else np.random.RandomState(seed)

    # Transform image to DCT domain
    _log.info("Performing DCT...")
    dct_image = dct2(image)
    x_true = dct_image.flatten()
    n = x_true.size

    # Simulate compressed sensing with random Gaussian measurements.
    _log.info("Randomizing sparsity...")
    # At least one measurement, so the power iteration never normalizes a zero vector.
    m = max(1, int(n * sampling_ratio))
    A = rng.randn(m, n)
    y = A @ x_true

    _log.info("Setting up ISTA optimization...")
    x_k = np.zeros(n)

    # The gradient A^T (A x - y) of 0.5 * ||A x - y||^2 is Lipschitz with
    # constant lambda_max(A^T A). Estimate it by power iteration instead of
    # forming the n x n matrix A^T A.
    _log.info("Computing step size using power iteration (faster than direct eigenvalue computation)...")
    start_time = time.time()
    v = rng.randn(n)
    v /= np.linalg.norm(v)
    for _ in range(20):  # Usually 20 iterations is enough for good approximation
        ATAv = A.T @ (A @ v)
        v = ATAv / np.linalg.norm(ATAv)
    lambda_max = np.linalg.norm(A @ v) ** 2  # Rayleigh quotient v^T A^T A v (v is unit-norm)
    L = lambda_max * 1.1  # Slightly larger for numerical stability
    step_size = 1 / L
    _log.info(f"Step size computation completed in {time.time() - start_time:.2f} seconds")

    # ISTA (Iterative Shrinkage-Thresholding Algorithm)
    _log.info("Performing ISTA recovery...")
    start_time = time.time()
    for i in range(iterations):
        # Gradient step on the smooth term, then shrinkage for the L1 term.
        gradient = A.T @ (A @ x_k - y)
        x_grad = x_k - step_size * gradient
        x_k = soft_threshold(x_grad, step_size * lambda_param)

        # Print progress every 10 iterations
        if i % 10 == 0:
            residual = np.linalg.norm(A @ x_k - y)
            sparsity = np.sum(np.abs(x_k) > 1e-6) / n * 100
            _log.info(f"Iteration {i}: Residual = {residual:.6f}, Nonzeros = {sparsity:.2f}%")
    end_time = time.time()
    _log.info(f"ISTA completed in {end_time - start_time:.2f} seconds")

    # Reshape and reconstruct image
    _log.info("Reconstructing image...")
    x_rec = x_k.reshape(dct_image.shape)
    return idct2(x_rec)
def process_image_slow(image, *, sampling_ratio=0.2, seed=None):
    """Recover *image* from simulated measurements by exact L1 minimization (CVXPY).

    The image is moved to the 2-D DCT domain, ``m = int(n * sampling_ratio)``
    random Gaussian measurements ``y = A @ x_true`` of the flattened
    coefficients are simulated, and the coefficients are recovered by solving
    ``min ||x||_1  subject to  A x = y`` with CVXPY's default solver.

    Parameters
    ----------
    image : array
        Input image; passed to ``dct2`` and reshaped back to its shape.
    sampling_ratio : float
        Number of measurements as a fraction of the number of coefficients.
    seed : int or None
        Seed for the measurement matrix; ``None`` (default) uses NumPy's
        global random state.

    Returns
    -------
    ndarray
        Reconstructed image (floating point, not clipped), same shape as *image*.

    Raises
    ------
    RuntimeError
        If the solver finishes without a primal solution (``x.value is None``).
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Transform image to DCT domain
    _log.info("Performing DCT...")
    dct_image = dct2(image)
    x_true = dct_image.flatten()
    n = x_true.size

    # Simulate compressed sensing with random Gaussian measurements.
    _log.info("Randomizing sparsity...")
    m = max(1, int(n * sampling_ratio))
    A = rng.randn(m, n)
    y = A @ x_true

    # L1 recovery using CVXPY
    _log.info("Performing L1 recovery...")
    x = cp.Variable(n)
    _log.info("--> Minimize...")
    objective = cp.Minimize(cp.norm1(x))
    constraints = [A @ x == y]
    _log.info("--> Solve...")
    problem = cp.Problem(objective, constraints)
    problem.solve()
    # A failed/infeasible solve leaves x.value as None; fail with a clear
    # message instead of an AttributeError on .reshape below.
    if x.value is None:
        raise RuntimeError(f"CVXPY returned no solution (status: {problem.status})")

    # Reshape and reconstruct image
    _log.info("Reconstructing image...")
    x_rec = x.value.reshape(dct_image.shape)
    return idct2(x_rec)
def save_and_display(ij, image, title):
    """Write *image* unchanged to ``<title>.tif`` and queue it for display in Fiji."""
    imsave(f"{title}.tif", image)

    def show_in_fiji():
        ij.ui().show(title, ij.py.to_java(image))

    # Display is deferred to ImageJ's thread service.
    ij.thread().queue(show_in_fiji)
# Enable logging.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# For displaying images.
_log.info("Initializing Fiji...")
import imagej

ij = imagej.init("sc.fiji:fiji:2.16.0", mode="interactive")
ij.thread().queue(lambda: ij.ui().showUI())

# Downsampled, grayscale astronaut.
size = 50
_log.info("Loading astronaut...")
astro = resize(astronaut(), (size, size, 3), anti_aliasing=True, preserve_range=True)
# Average the 3 channels in floating point and round once; casting to uint8
# before and after averaging truncates twice and biases intensities downward.
astro = np.rint(np.mean(astro, axis=2)).astype(np.uint8)
_log.info(f"{astro.shape} {astro.dtype}")
save_and_display(ij, astro, f"astro-{size}-original")

# ISTA with 10, 100, 1000 and 10000 iterations.
_log.info("Processing image with ISTA...")
for power in range(1, 5):
    iterations = 10 ** power
    recovered = process_image_fast(astro, iterations=iterations)
    save_and_display(ij, recovered, f"astro-{size}-ista-{iterations}")

# Exact L1 (basis pursuit) solution for comparison.
_log.info("Processing image with CVXPY...")
start_time = time.time()
recovered = process_image_slow(astro)
_log.info(f"CVXPY completed in {time.time() - start_time:.2f} seconds")
save_and_display(ij, recovered, f"astro-{size}-cvxpy")

# Keep the process (and with it the Fiji UI) alive until it is killed.
time.sleep(9999999)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Steps to test this script: | |
# | |
# mamba create -n sparse -c conda-forge numpy scipy scikit-image pyimagej
# mamba activate sparse | |
# python sparse2.py | |
import logging | |
import time | |
import numpy as np | |
from scipy.fftpack import dct, idct | |
from skimage.data import astronaut | |
from skimage.io import imsave | |
from skimage.transform import resize | |
# Module-level logger used for all progress messages in this script.
_log = logging.getLogger(name=__name__)
# 2D Discrete Cosine Transform (DCT) and inverse
def dct2(img):
    """Return the orthonormal type-II DCT of *img*, first along axis 0, then along the last axis."""
    along_axis0 = dct(img, axis=0, norm='ortho')
    return dct(along_axis0, axis=-1, norm='ortho')
def idct2(coeffs):
    """Invert ``dct2``: orthonormal inverse DCT along axis 0, then along the last axis."""
    along_axis0 = idct(coeffs, axis=0, norm='ortho')
    return idct(along_axis0, axis=-1, norm='ortho')
def process_image_fista(image, iterations=1000, *, sampling_ratio=0.2,
                        lambda_param=0.001, seed=42):
    """
    FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) implementation
    which has faster convergence than basic ISTA.

    The image is moved to the 2-D DCT domain, ``m = int(n * sampling_ratio)``
    random Gaussian measurements ``y = A @ x_true`` are simulated, and the
    coefficients are re-estimated by running FISTA on
    ``0.5 * ||A x - y||^2 + lambda_param * ||x||_1``.

    Parameters
    ----------
    image : array
        Input image; passed to ``dct2`` and reshaped back to its shape.
    iterations : int
        Number of FISTA iterations.
    sampling_ratio : float
        Number of measurements as a fraction of the number of coefficients.
    lambda_param : float
        L1 regularization weight.
    seed : int or None
        Seed for the measurement matrix and power-iteration start vector
        (default 42). ``None`` draws from NumPy's global random state.

    Returns
    -------
    recovered_image : ndarray
        Final reconstruction, same shape as *image*.
    results : list of (int, ndarray)
        ``(iteration_count, reconstruction)`` snapshots after iterations
        1, 10, 100, ... (every power of ten reached) and after the last one.
    residuals : list of float
        ``||A x - y||`` at each logged iteration.
    """
    # Soft thresholding operator (proximal operator for L1 norm)
    def soft_threshold(x, threshold):
        return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)

    # Create the generator BEFORE drawing the measurement matrix so that A
    # itself (not only the power-iteration start vector) is reproducible.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Transform image to DCT domain
    _log.info("Performing DCT...")
    dct_image = dct2(image)
    x_true = dct_image.flatten()
    n = x_true.size

    # Simulate compressed sensing with random Gaussian measurements.
    _log.info("Randomizing sparsity...")
    m = max(1, int(n * sampling_ratio))
    A = rng.randn(m, n)
    y = A @ x_true

    # Initialize solution
    _log.info("Setting up FISTA optimization...")
    x_k = np.zeros(n)
    z_k = np.zeros(n)
    t_k = 1

    # Step size 1/L, where L ~ lambda_max(A^T A) is the Lipschitz constant of
    # the gradient of 0.5 * ||A x - y||^2; estimated by power iteration.
    _log.info("Computing step size using power iteration...")
    start_time = time.time()
    v = rng.randn(n)
    v /= np.linalg.norm(v)
    for _ in range(20):  # Usually 20 iterations is enough for good approximation
        ATAv = A.T @ (A @ v)
        v = ATAv / np.linalg.norm(ATAv)
    lambda_max = np.linalg.norm(A @ v) ** 2  # Rayleigh quotient (v is unit-norm)
    L = lambda_max * 1.1  # Slightly larger for numerical stability
    step_size = 1 / L
    _log.info(f"Step size computation completed in {time.time() - start_time:.2f} seconds")

    # Zero-based indices i whose iteration count i + 1 is a power of ten.
    decades = {10 ** k - 1 for k in range(len(str(max(iterations, 1))))}
    # Guard against iterations < 10, where iterations // 10 would be 0.
    log_every = max(1, iterations // 10)

    # Storage for intermediate results
    results = []
    residuals = []

    _log.info("Performing FISTA recovery...")
    start_time = time.time()
    for i in range(iterations):
        # Proximal gradient step at the extrapolated point z_k.
        gradient = A.T @ (A @ z_k - y)
        x_next = soft_threshold(z_k - step_size * gradient, step_size * lambda_param)
        # Nesterov momentum update
        t_next = (1 + np.sqrt(1 + 4 * t_k**2)) / 2
        z_next = x_next + ((t_k - 1) / t_next) * (x_next - x_k)
        x_k, z_k, t_k = x_next, z_next, t_next

        # Store intermediate results at logarithmic intervals (and at the end).
        if i in decades or i == iterations - 1:
            results.append((i + 1, idct2(x_k.reshape(dct_image.shape))))

        # Log progress at logarithmic intervals and every ~10% of the run.
        if i in decades or i % log_every == 0:
            residual = np.linalg.norm(A @ x_k - y)
            sparsity = np.sum(np.abs(x_k) > 1e-6) / n * 100
            _log.info(f"Iteration {i+1}: Residual = {residual:.6f}, Nonzeros = {sparsity:.2f}%")
            residuals.append(residual)
    end_time = time.time()
    _log.info(f"FISTA completed in {end_time - start_time:.2f} seconds")

    # Final reconstruction
    recovered_image = idct2(x_k.reshape(dct_image.shape))
    return recovered_image, results, residuals
def annealing_fista(image, iterations=1000, annealing_steps=5, *,
                    sampling_ratio=0.2, lambda_init=0.01, lambda_decay=0.5,
                    seed=42):
    """
    Multi-stage FISTA with annealing of the regularization parameter.

    Runs ``annealing_steps`` FISTA stages of ``iterations // annealing_steps``
    iterations each on ``0.5 * ||A x - y||^2 + lam * ||x||_1`` in the DCT
    domain. Each stage warm-starts from the previous solution (momentum is
    reset), and ``lam`` starts at ``lambda_init`` and is multiplied by
    ``lambda_decay`` after every stage.

    Parameters
    ----------
    image : array
        Input image; passed to ``dct2`` and reshaped back to its shape.
    iterations : int
        Total iteration budget, split evenly across stages (remainder dropped).
    annealing_steps : int
        Number of stages; must be >= 1.
    sampling_ratio : float
        Number of measurements as a fraction of the number of coefficients.
    lambda_init, lambda_decay : float
        Initial L1 weight and its per-stage multiplier.
    seed : int or None
        Seed for the measurement matrix and power iteration (default 42,
        which gives the same draws as ``np.random.seed(42)``). ``None`` uses
        NumPy's global random state.

    Returns
    -------
    ndarray
        Reconstructed image (floating point), same shape as *image*.

    Raises
    ------
    ValueError
        If ``annealing_steps < 1``.
    """
    if annealing_steps < 1:
        raise ValueError(f"annealing_steps must be >= 1, got {annealing_steps}")

    # Transform image to DCT domain
    _log.info("Performing DCT...")
    dct_image = dct2(image)
    x_true = dct_image.flatten()
    n = x_true.size

    # Simulate compressed sensing with random Gaussian measurements.
    _log.info("Randomizing sparsity...")
    m = max(1, int(n * sampling_ratio))
    # Local generator: identical draws to seeding the global RNG, without
    # mutating NumPy's global state.
    rng = np.random if seed is None else np.random.RandomState(seed)
    A = rng.randn(m, n)
    y = A @ x_true

    # Soft thresholding operator (proximal operator for L1 norm)
    def soft_threshold(x, threshold):
        return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)

    # Step size from a power-iteration estimate of lambda_max(A^T A).
    _log.info("Computing step size...")
    v = rng.randn(n)
    v /= np.linalg.norm(v)
    for _ in range(20):
        ATAv = A.T @ (A @ v)
        v = ATAv / np.linalg.norm(ATAv)
    lambda_max = np.linalg.norm(A @ v) ** 2  # Rayleigh quotient (v is unit-norm)
    L = lambda_max * 1.1
    step_size = 1 / L

    lambda_param = lambda_init
    x_k = np.zeros(n)

    _log.info("Starting annealing FISTA...")
    annealing_iters = iterations // annealing_steps
    # Guard against annealing_iters < 5, where annealing_iters // 5 would be 0.
    log_every = max(1, annealing_iters // 5)
    for stage in range(annealing_steps):
        _log.info(f"Annealing stage {stage+1}/{annealing_steps}, lambda = {lambda_param:.6f}")
        # Restart momentum from the current solution.
        z_k = x_k.copy()
        t_k = 1
        for i in range(annealing_iters):
            gradient = A.T @ (A @ z_k - y)
            x_next = soft_threshold(z_k - step_size * gradient, step_size * lambda_param)
            t_next = (1 + np.sqrt(1 + 4 * t_k**2)) / 2
            z_next = x_next + ((t_k - 1) / t_next) * (x_next - x_k)
            x_k, z_k, t_k = x_next, z_next, t_next

            if i % log_every == 0:
                residual = np.linalg.norm(A @ x_k - y)
                _log.info(f"  Iteration {i}: Residual = {residual:.6f}")
        # Reduce lambda parameter for next stage
        lambda_param *= lambda_decay

    # Final reconstruction
    x_rec = x_k.reshape(dct_image.shape)
    return idct2(x_rec)
def save_and_display(ij, image, title):
    """Write *image* unchanged to ``<title>.tif`` and queue it for display in Fiji."""
    imsave(f"{title}.tif", image)

    def show_in_fiji():
        ij.ui().show(title, ij.py.to_java(image))

    # Display is deferred to ImageJ's thread service.
    ij.thread().queue(show_in_fiji)
# Enable logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# For displaying images
_log.info("Initializing Fiji...")
import imagej

ij = imagej.init("sc.fiji:fiji:2.16.0", mode="interactive")
ij.thread().queue(lambda: ij.ui().showUI())

# Load and prepare image: downsample, then convert to grayscale.
size = 50
_log.info(f"Loading astronaut at {size}x{size}...")
astro = resize(astronaut(), (size, size, 3), anti_aliasing=True, preserve_range=True)
# Average the 3 channels in floating point and round once; casting to uint8
# before and after averaging truncates twice and biases intensities downward.
astro = np.rint(np.mean(astro, axis=2)).astype(np.uint8)
_log.info(f"{astro.shape} {astro.dtype}")
save_and_display(ij, astro, f"astro-{size}-original")

# Standard FISTA
_log.info("Processing with improved FISTA algorithm...")
recovered_fista, results, residuals = process_image_fista(astro, iterations=1000)
save_and_display(ij, recovered_fista, f"astro-{size}-fista-improved")

# Save intermediate (iteration_count, image) snapshots
for iter_num, img in results:
    save_and_display(ij, img, f"astro-{size}-fista-iter-{iter_num}")

# Annealing method (multi-stage with decreasing lambda)
_log.info("Processing with annealing FISTA algorithm...")
recovered_annealing = annealing_fista(astro, iterations=1000, annealing_steps=5)
save_and_display(ij, recovered_annealing, f"astro-{size}-fista-annealing")

# Keep the process (and with it the Fiji UI) alive until it is killed.
time.sleep(9999999)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Steps to test this script: | |
# | |
# mamba create -n sparse -c conda-forge numpy scipy scikit-image pywavelets pyimagej
# mamba activate sparse | |
# python sparse3.py | |
import logging | |
import time | |
import numpy as np | |
import pywt # PyWavelets | |
from scipy import optimize | |
from skimage.data import astronaut | |
from skimage.io import imsave | |
from skimage.transform import resize | |
from skimage.metrics import structural_similarity as ssim | |
from skimage.metrics import peak_signal_noise_ratio as psnr | |
# Module-level logger used for all progress messages in this script.
_log = logging.getLogger(name=__name__)
def process_image_wavelet(image, sampling_ratio=0.2, wavelet='db4', level=2):
    """
    Compressed sensing recovery using wavelet sparsity.

    Simulates ``m = int(h * w * sampling_ratio)`` random Gaussian
    measurements ``y = A @ image.flatten()`` and reconstructs the image in
    the pixel domain by minimizing::

        ||A x - y||^2 + lam * sum(|pywt.wavedec2(x, wavelet, level)|)

    with L-BFGS-B in two warm-started stages: lam = 0.01 (at most 50
    iterations), then lam = 0.001 (at most 100 iterations).

    Parameters
    ----------
    image : 2-D array
        Ground-truth image; unpacked as ``h, w = image.shape``.
    sampling_ratio : float
        Number of measurements as a fraction of the pixel count.
    wavelet : str
        Wavelet name passed to ``pywt.wavedec2``.
    level : int
        Decomposition level passed to ``pywt.wavedec2``.

    Returns
    -------
    ndarray
        Reconstructed ``(h, w)`` image clipped to [0, 255].

    Notes
    -----
    No gradient is passed to ``optimize.minimize``, so SciPy estimates it by
    finite differences (roughly one objective evaluation per pixel for every
    gradient), which is slow. The L1 term is also non-smooth, which
    L-BFGS-B does not handle exactly.
    """
    h, w = image.shape
    n = h * w

    def wavelet_l1(coeffs):
        """Sum of |c| over the approximation band and every detail band."""
        approx, *details = coeffs
        return np.sum(np.abs(approx)) + sum(
            np.sum(np.abs(band)) for level_bands in details for band in level_bands
        )

    # Create random measurement matrix. A dedicated RandomState(42) yields the
    # same matrix as np.random.seed(42) without reseeding NumPy's global RNG.
    _log.info("Creating random sampling matrix...")
    rng = np.random.RandomState(42)  # For reproducibility
    m = int(n * sampling_ratio)
    A = rng.randn(m, n)

    # Obtain measurements
    _log.info("Taking random measurements...")
    y = A @ image.flatten()

    # Size of the wavelet representation (may exceed h*w due to boundary padding).
    approx, *details = pywt.wavedec2(image, wavelet, level=level)
    total_coeffs = approx.size + sum(
        band.size for level_bands in details for band in level_bands
    )
    _log.info(f"Total wavelet coefficients: {total_coeffs}")

    def direct_recovery_objective(x, lambda_param):
        """Squared data misfit plus weighted wavelet-domain L1 norm of pixel vector x."""
        mse = np.linalg.norm(A @ x - y) ** 2
        coeffs = pywt.wavedec2(x.reshape((h, w)), wavelet, level=level)
        return mse + lambda_param * wavelet_l1(coeffs)

    # Initialize with zeros
    x0 = np.zeros(n)

    # Two-stage optimization: stronger regularization first, then refine.
    _log.info("Stage 1: Initial reconstruction...")
    lambda1 = 0.01
    result1 = optimize.minimize(
        lambda x: direct_recovery_objective(x, lambda1),
        x0,
        method='L-BFGS-B',
        options={'maxiter': 50, 'disp': True},
    )

    _log.info("Stage 2: Refining reconstruction...")
    lambda2 = 0.001
    result2 = optimize.minimize(
        lambda x: direct_recovery_objective(x, lambda2),
        result1.x,  # Use previous result as starting point
        method='L-BFGS-B',
        options={'maxiter': 100, 'disp': True},
    )

    # Get final reconstructed image, restricted to the valid 8-bit range.
    _log.info("Finalizing reconstructed image...")
    reconstructed_image = result2.x.reshape((h, w))
    return np.clip(reconstructed_image, 0, 255)
def save_and_display(ij, image, title):
    """Clip *image* to 0-255, store it as ``<title>.tif`` and queue it for display in Fiji.

    The same uint8 copy is both saved and displayed.
    """
    as_uint8 = np.clip(image, 0, 255).astype(np.uint8)
    imsave(f"{title}.tif", as_uint8)

    def show_in_fiji():
        ij.ui().show(title, ij.py.to_java(as_uint8))

    # Display is deferred to ImageJ's thread service.
    ij.thread().queue(show_in_fiji)
# Enable logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# For displaying images
_log.info("Initializing Fiji...")
import imagej

ij = imagej.init("sc.fiji:fiji:2.16.0", mode="interactive")
ij.thread().queue(lambda: ij.ui().showUI())

# Load and prepare image: downsample, then convert to grayscale.
size = 50
_log.info(f"Loading astronaut at {size}x{size}...")
astro = resize(astronaut(), (size, size, 3), anti_aliasing=True, preserve_range=True)
# Average the 3 channels in floating point and round once; casting to uint8
# before and after averaging truncates twice and biases intensities downward.
astro = np.rint(np.mean(astro, axis=2)).astype(np.uint8)
_log.info(f"Image shape: {astro.shape}, dtype: {astro.dtype}")
save_and_display(ij, astro, f"astro-{size}-original")

# Try different types of wavelets
wavelets = ['haar', 'db2', 'sym3']  # Simpler wavelets may work better for smaller images
for wavelet in wavelets:
    _log.info(f"Processing with wavelet: {wavelet}")
    recovered = process_image_wavelet(astro, wavelet=wavelet)
    save_and_display(ij, recovered, f"astro-{size}-wavelet-{wavelet}")

    # Image quality metrics against the ground-truth grayscale image.
    psnr_value = psnr(astro, recovered, data_range=255)
    ssim_value = ssim(astro, recovered, data_range=255)
    _log.info(f"Wavelet {wavelet}: PSNR = {psnr_value:.2f} dB, SSIM = {ssim_value:.4f}")

# Keep the process (and with it the Fiji UI) alive until it is killed.
time.sleep(9999999)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Simple TV-L1 compressed sensing implementation | |
import logging | |
import time | |
import numpy as np | |
from scipy import optimize | |
from skimage.data import astronaut | |
from skimage.io import imsave | |
from skimage.transform import resize | |
from skimage.metrics import structural_similarity as ssim | |
from skimage.metrics import peak_signal_noise_ratio as psnr | |
# Module-level logger used for all progress messages in this script.
_log = logging.getLogger(name=__name__)
def tv_norm(x, image_shape, eps=1e-10):
    """Compute the smoothed isotropic total variation (TV) of an image.

    Only the TV value is returned, not its gradient.

    ``x`` is reshaped to ``image_shape``. Forward differences are taken along
    both axes (set to zero in the last column / last row), and the result is
    ``sum(sqrt(dx**2 + dy**2 + eps))``.

    Parameters
    ----------
    x : array
        Image values; anything that reshapes to ``image_shape``.
    image_shape : tuple of int
        2-D shape ``(h, w)`` of the image.
    eps : float
        Smoothing term that keeps the square root differentiable where the
        local gradient vanishes (default 1e-10).

    Returns
    -------
    float
        The TV value.
    """
    img = x.reshape(image_shape)
    # Compute horizontal and vertical forward differences (zero-padded).
    h_grad = np.zeros_like(img)
    v_grad = np.zeros_like(img)
    h_grad[:, :-1] = np.diff(img, axis=1)
    v_grad[:-1, :] = np.diff(img, axis=0)
    # Isotropic TV: sum of the per-pixel gradient magnitudes.
    return np.sum(np.sqrt(h_grad**2 + v_grad**2 + eps))
def _tv_norm_and_grad(img, eps=1e-10):
    """Return ``(tv, grad)``: the smoothed isotropic TV of 2-D float *img* and its gradient.

    Uses the same definition as ``tv_norm``: forward differences along both
    axes (zero in the last column / last row), combined as
    ``sum(sqrt(dx**2 + dy**2 + eps))``. ``grad`` has the shape of *img*.
    """
    dx = np.zeros_like(img)
    dy = np.zeros_like(img)
    dx[:, :-1] = np.diff(img, axis=1)
    dy[:-1, :] = np.diff(img, axis=0)
    mag = np.sqrt(dx**2 + dy**2 + eps)
    px = dx / mag
    py = dy / mag
    # Adjoint of the forward differences: the term img[i, j+1] - img[i, j]
    # contributes +px[i, j] to pixel (i, j+1) and -px[i, j] to pixel (i, j);
    # likewise vertically with py.
    grad = np.zeros_like(img)
    grad[:, 1:] += px[:, :-1]
    grad[:, :-1] -= px[:, :-1]
    grad[1:, :] += py[:-1, :]
    grad[:-1, :] -= py[:-1, :]
    return np.sum(mag), grad


def compressed_sensing_tv(image, sampling_ratio=0.2, lambda_tv=0.1):
    """
    Simple TV-regularized compressed sensing reconstruction.

    Simulates ``m = int(h * w * sampling_ratio)`` random Gaussian
    measurements ``y = A @ image.flatten()`` and minimizes::

        0.5 * ||A x - y||^2 + lambda_tv * TV(x)

    with L-BFGS-B (at most 500 iterations). The analytic gradient is
    supplied (``jac=True``). Without it, SciPy would fall back to finite
    differences, costing roughly one objective evaluation per pixel for
    every gradient.

    Parameters
    ----------
    image : 2-D array
        Ground-truth image; unpacked as ``h, w = image.shape``.
    sampling_ratio : float
        Number of measurements as a fraction of the pixel count.
    lambda_tv : float
        Weight of the TV term.

    Returns
    -------
    ndarray
        Reconstructed ``(h, w)`` image clipped to [0, 255].
    """
    # Original image dimensions
    h, w = image.shape
    n = h * w
    image_shape = (h, w)

    # Create random measurement matrix. A dedicated RandomState(42) yields the
    # same matrix as np.random.seed(42) without reseeding NumPy's global RNG.
    _log.info("Creating random sampling matrix...")
    rng = np.random.RandomState(42)
    m = int(n * sampling_ratio)
    A = rng.randn(m, n)

    # Obtain measurements
    _log.info("Taking random measurements...")
    y = A @ image.flatten()

    def objective(x):
        """Data fidelity + TV regularization, returned together with its gradient."""
        residual = A @ x - y
        data_term = 0.5 * np.sum(residual**2)
        tv_value, tv_grad = _tv_norm_and_grad(x.reshape(image_shape))
        value = data_term + lambda_tv * tv_value
        grad = A.T @ residual + lambda_tv * tv_grad.ravel()
        return value, grad

    # Start from the backprojection A^T y rescaled to the image's mean.
    # NOTE: this uses the ground-truth mean intensity, which a real sensing
    # setup would not have.
    _log.info("Computing initial solution...")
    x0 = A.T @ y
    x0 = x0 * np.mean(image) / (np.mean(x0) + 1e-10)

    # Optimize
    _log.info("Starting optimization...")
    start_time = time.time()
    result = optimize.minimize(
        objective,
        x0,
        jac=True,  # objective returns (value, gradient)
        method='L-BFGS-B',
        options={'maxiter': 500, 'disp': True},
    )
    end_time = time.time()
    _log.info(f"Optimization completed in {end_time - start_time:.2f} seconds")

    # Reshape and clip result
    reconstructed = result.x.reshape(image_shape)
    return np.clip(reconstructed, 0, 255)
def try_multiple_lambdas(image, lambdas=(0.01, 0.1, 1.0)):
    """Reconstruct *image* once per TV weight and score each result.

    Parameters
    ----------
    image : 2-D array
        Ground-truth image. PSNR/SSIM use ``data_range=255``, so an 8-bit
        intensity range is assumed.
    lambdas : iterable of float
        TV weights to try. The default is a tuple (immutable) rather than a
        list, so it cannot be mutated and shared across calls.

    Returns
    -------
    list of tuple
        ``(lambda_tv, reconstructed, psnr, ssim)`` for each weight, in input order.
    """
    results = []
    for lambda_tv in lambdas:
        _log.info(f"Trying lambda_tv = {lambda_tv}")
        reconstructed = compressed_sensing_tv(image, lambda_tv=lambda_tv)

        # Calculate quality metrics
        p = psnr(image, reconstructed, data_range=255)
        s = ssim(image, reconstructed, data_range=255)
        _log.info(f"λ={lambda_tv}: PSNR={p:.2f}dB, SSIM={s:.4f}")
        results.append((lambda_tv, reconstructed, p, s))
    return results
def save_and_display(ij, image, title):
    """Clip *image* to 0-255, store it as ``<title>.tif`` and queue it for display in Fiji.

    The same uint8 copy is both saved and displayed.
    """
    as_uint8 = np.clip(image, 0, 255).astype(np.uint8)
    imsave(f"{title}.tif", as_uint8)

    def show_in_fiji():
        ij.ui().show(title, ij.py.to_java(as_uint8))

    # Display is deferred to ImageJ's thread service.
    ij.thread().queue(show_in_fiji)
# Main execution
if __name__ == "__main__":
    # Enable logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # For displaying images
    _log.info("Initializing Fiji...")
    import imagej

    ij = imagej.init("sc.fiji:fiji:2.16.0", mode="interactive")
    ij.thread().queue(lambda: ij.ui().showUI())

    # Load and prepare image: downsample, then convert to grayscale.
    size = 50
    _log.info(f"Loading astronaut at {size}x{size}...")
    astro = resize(astronaut(), (size, size, 3), anti_aliasing=True, preserve_range=True)
    # Average the 3 channels in floating point and round once; casting to
    # uint8 before and after averaging truncates twice and biases
    # intensities downward.
    astro = np.rint(np.mean(astro, axis=2)).astype(np.uint8)
    save_and_display(ij, astro, f"astro-{size}-original")

    # Try multiple lambda values
    lambdas = [0.001, 0.01, 0.1]
    results = try_multiple_lambdas(astro, lambdas)

    # Save all results (PSNR/SSIM were already logged by try_multiple_lambdas)
    for lambda_tv, reconstructed, _psnr, _ssim in results:
        save_and_display(ij, reconstructed, f"astro-{size}-tv-lambda-{lambda_tv}")

    # Keep the process (and with it the Fiji UI) alive until it is killed.
    time.sleep(9999999)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment