Last active
October 26, 2023 18:50
-
-
Save johnhw/5e3ad50326e3ab6402d008ade35b3ff1 to your computer and use it in GitHub Desktop.
Bivariate colormap in matplotlib
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Bivariate colormaps in matplotlib
import numpy as np | |
import matplotlib as mpl | |
import matplotlib.pyplot as plt | |
## Bivariate colormapping
from scipy.interpolate import RegularGridInterpolator | |
def bivariate_colormap(img, colormap, *, method="linear"):
    """Map a two-channel image to RGB through a 2-D colour lookup table.

    Each pixel's pair (img[:, :, 0], img[:, :, 1]) is treated as a point in
    the unit square and looked up in ``colormap`` by interpolation: channel 0
    selects the position along the colormap's rows (axis 0) and channel 1 the
    position along its columns (axis 1).

    Args:
        img: Array of shape (H, W, C) with C >= 2; only the first two channels
            are used. Values must lie in [0, 1]; out-of-range values make
            ``RegularGridInterpolator`` raise ``ValueError``.
        colormap: Array of shape (R, K, C') with C' >= 3 (e.g. an RGB image).
            Only the first three channels are used, so an alpha channel is
            ignored.
        method: Interpolation method passed to ``RegularGridInterpolator``
            (default ``"linear"``; ``"nearest"`` gives a blocky lookup).

    Returns:
        Array of shape (H, W, 3) with the interpolated colours.
    """
    rows = np.linspace(0, 1, colormap.shape[0])
    cols = np.linspace(0, 1, colormap.shape[1])
    # One vector-valued interpolator handles all three colour channels at
    # once, rather than building a separate interpolator per channel.
    lookup = RegularGridInterpolator((rows, cols), colormap[:, :, :3], method=method)
    height, width = img.shape[:2]
    # (H*W, 2) query points, built once instead of once per channel.
    pts = np.stack([img[:, :, 0].ravel(), img[:, :, 1].ravel()], axis=-1)
    return lookup(pts).reshape(height, width, 3)
### Demo colormap
def make_colormap(n, cyclic_x=False, cyclic_y=False):
    """Build an n x n RGB lookup table suitable for ``bivariate_colormap``.

    The table is designed in CIELAB space and converted with
    ``skimage.color.lab2rgb``. Over a grid spanning [-100, 100] on both axes
    (x varies along columns, y along rows):

    * L* = 75 + 0.15*y - 0.3*x
    * a* = 100 * tanh(x / 130)
    * b* = 100 * tanh(y / 190)

    Args:
        n: Number of samples along each axis.
        cyclic_x: If True, replace x with 50*cos(pi*x/100) before computing
            the colours, so the first and last columns get the same x value.
            Because cos is even, the result is mirror-symmetric about the
            centre column rather than a full loop.
        cyclic_y: Same as ``cyclic_x``, applied to y (the rows).

    Returns:
        The RGB image returned by ``skimage.color.lab2rgb`` for the
        (n, n, 3) Lab grid.
    """
    import skimage.color

    axis = np.linspace(-100, 100, n)
    gx, gy = np.meshgrid(axis, axis)
    if cyclic_x:
        gx = np.cos(gx / 100 * np.pi) * 50.0
    if cyclic_y:
        gy = np.cos(gy / 100 * np.pi) * 50.0

    # Lightness uses the (possibly folded) coordinates before the tanh
    # squashing applied to the chroma channels below.
    lightness = 75 + gy * 0.15 - gx * 0.3
    a_star = np.tanh(gx / 130) * 100
    b_star = np.tanh(gy / 190) * 100

    lab = np.stack([lightness, a_star, b_star], axis=-1)
    return skimage.color.lab2rgb(lab)
# Preview the demo lookup table on its own (256 samples per axis).
rgb_img = make_colormap(n=256)
plt.imshow(rgb_img)
### Demo
import scipy.special | |
# Sample the complex plane on a 100 x 100 grid covering [-3, 3] x [-3, 3].
x = np.linspace(-3, 3, 100)
mx, my = np.meshgrid(x, x)
z = mx + 1j * my
# Complex power with base 1+i; its real and imaginary parts become the two
# channels that the bivariate colormap will encode.
zf = np.power(1 + 1j, z)
zfc = np.stack([zf.real, zf.imag], axis=-1)
# Alternative encoding (needs a colormap that is cyclic along the angle axis):
# zfc = np.dstack([np.abs(zf), np.angle(zf)])
def normalize(x):
    """Linearly rescale *x* so its minimum becomes 0 and its maximum 1.

    A constant input (max == min) has no range to stretch. Instead of the
    all-NaN result (plus RuntimeWarning) that 0/0 would give, it maps to all
    zeros.

    Args:
        x: Array-like of real numbers, any shape.

    Returns:
        Array of the same shape; for finite input its values lie in [0, 1].
    """
    x = np.asarray(x)
    lo = np.min(x)
    span = np.max(x) - lo
    if span == 0:
        # Every element equals lo, so x - lo is all zeros; dividing by 1 keeps
        # the normal true-division result dtype.
        span = 1
    return (x - lo) / span
# Rescale each channel independently to [0, 1] so it spans the unit-square
# coordinates that bivariate_colormap looks up (exactly 0 and 1 at the ends,
# which keeps the interpolator in bounds).
for channel in range(zfc.shape[2]):
    zfc[:, :, channel] = normalize(zfc[:, :, channel])

# Draw in a fresh figure: otherwise imshow paints over the colormap preview
# already drawn into the current axes.
plt.figure()
plt.imshow(bivariate_colormap(zfc, rgb_img))
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi there!
Thank you for the code! Could you please attach some figures, so that I can understand what kind of bivariate figure you intend to make here?